From 959c64e28ab1a85851fdef68c6488ecb685e81bf Mon Sep 17 00:00:00 2001 From: yuye-aws Date: Tue, 12 Mar 2024 16:51:13 +0800 Subject: [PATCH] adjust functions in chunker interface Signed-off-by: yuye-aws --- .../processor/TextChunkingProcessor.java | 7 +----- .../processor/chunker/Chunker.java | 20 +++------------- .../chunker/ChunkerParameterValidator.java | 3 +++ .../processor/chunker/DelimiterChunker.java | 13 ++++++---- .../chunker/FixedTokenLengthChunker.java | 4 ++-- .../chunker/DelimiterChunkerTests.java | 14 +++++------ .../chunker/FixedTokenLengthChunkerTests.java | 24 +++++++++---------- 7 files changed, 37 insertions(+), 48 deletions(-) diff --git a/src/main/java/org/opensearch/neuralsearch/processor/TextChunkingProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/TextChunkingProcessor.java index 8e46adcac..123d3d4ca 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/TextChunkingProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/TextChunkingProcessor.java @@ -148,12 +148,7 @@ private boolean isListOfString(Object value) { private int chunkString(String content, List result, Map runTimeParameters, int chunkCount) { // chunk the content, return the updated chunkCount and add chunk passages to result - List contentResult; - if (chunker instanceof FixedTokenLengthChunker) { - contentResult = chunker.chunk(content, runTimeParameters); - } else { - contentResult = chunker.chunk(content); - } + List contentResult = chunker.chunk(content, runTimeParameters); chunkCount += contentResult.size(); if (maxChunkLimit != DEFAULT_MAX_CHUNK_LIMIT && chunkCount > maxChunkLimit) { throw new IllegalArgumentException( diff --git a/src/main/java/org/opensearch/neuralsearch/processor/chunker/Chunker.java b/src/main/java/org/opensearch/neuralsearch/processor/chunker/Chunker.java index 29a0539f2..8419c1d98 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/chunker/Chunker.java +++ 
b/src/main/java/org/opensearch/neuralsearch/processor/chunker/Chunker.java @@ -4,8 +4,6 @@ */ package org.opensearch.neuralsearch.processor.chunker; -import com.google.common.collect.ImmutableList; - import java.util.Map; import java.util.List; @@ -16,22 +14,12 @@ public interface Chunker { /** - * Validate the parameters for chunking algorithm, + * Validate and parse the parameters for chunking algorithm, * will throw IllegalArgumentException when parameters are invalid * * @param parameters a map containing parameters for chunking algorithms */ - void validateParameters(Map parameters); - - /** - * Chunk the incoming string according to parameters and return chunked passages - * - * @param content input string - * @return Chunked passages - */ - default List chunk(String content) { - return ImmutableList.of(); - } + void validateAndParseParameters(Map parameters); /** * Chunk the incoming string according to parameters and return chunked passages @@ -40,7 +28,5 @@ default List chunk(String content) { * @param runtimeParameters a map containing runtime parameters for chunking algorithms * @return Chunked passages */ - default List chunk(String content, Map runtimeParameters) { - return ImmutableList.of(); - } + List chunk(String content, Map runtimeParameters); } diff --git a/src/main/java/org/opensearch/neuralsearch/processor/chunker/ChunkerParameterValidator.java b/src/main/java/org/opensearch/neuralsearch/processor/chunker/ChunkerParameterValidator.java index b3f399074..368f2cbfe 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/chunker/ChunkerParameterValidator.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/chunker/ChunkerParameterValidator.java @@ -9,6 +9,9 @@ import java.util.Map; +/** + * Validate and parse the parameter for chunking algorithms + */ public class ChunkerParameterValidator { public static String validateStringParameters( diff --git 
a/src/main/java/org/opensearch/neuralsearch/processor/chunker/DelimiterChunker.java b/src/main/java/org/opensearch/neuralsearch/processor/chunker/DelimiterChunker.java index f426f2b37..aabe1d4ae 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/chunker/DelimiterChunker.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/chunker/DelimiterChunker.java @@ -16,7 +16,7 @@ public class DelimiterChunker implements Chunker { public DelimiterChunker(Map parameters) { - validateParameters(parameters); + validateAndParseParameters(parameters); } public static final String DELIMITER_FIELD = "delimiter"; @@ -33,12 +33,18 @@ public DelimiterChunker(Map parameters) { * @throws IllegalArgumentException If delimiter is not a string or empty */ @Override - public void validateParameters(Map parameters) { + public void validateAndParseParameters(Map parameters) { this.delimiter = validateStringParameters(parameters, DELIMITER_FIELD, DEFAULT_DELIMITER, false); } + /** + * Return the chunked passages for delimiter algorithm + * + * @param content input string + * @param runtimeParameters a map for runtime parameters, but not needed by delimiter algorithm + */ @Override - public List chunk(String content) { + public List chunk(String content, Map runtimeParameters) { List chunkResult = new ArrayList<>(); int start = 0, end; int nextDelimiterPosition = content.indexOf(delimiter); @@ -56,5 +62,4 @@ public List chunk(String content) { return chunkResult; } - } diff --git a/src/main/java/org/opensearch/neuralsearch/processor/chunker/FixedTokenLengthChunker.java b/src/main/java/org/opensearch/neuralsearch/processor/chunker/FixedTokenLengthChunker.java index acdd65aea..2968ec9f5 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/chunker/FixedTokenLengthChunker.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/chunker/FixedTokenLengthChunker.java @@ -44,7 +44,7 @@ public class FixedTokenLengthChunker implements Chunker { private 
final AnalysisRegistry analysisRegistry; public FixedTokenLengthChunker(Map parameters) { - validateParameters(parameters); + validateAndParseParameters(parameters); this.analysisRegistry = (AnalysisRegistry) parameters.get(ANALYSIS_REGISTRY_FIELD); } @@ -63,7 +63,7 @@ public FixedTokenLengthChunker(Map parameters) { * tokenizer should be string */ @Override - public void validateParameters(Map parameters) { + public void validateAndParseParameters(Map parameters) { this.tokenLimit = validatePositiveIntegerParameter(parameters, TOKEN_LIMIT_FIELD, DEFAULT_TOKEN_LIMIT); if (parameters.containsKey(OVERLAP_RATE_FIELD)) { String overlapRateString = parameters.get(OVERLAP_RATE_FIELD).toString(); diff --git a/src/test/java/org/opensearch/neuralsearch/processor/chunker/DelimiterChunkerTests.java b/src/test/java/org/opensearch/neuralsearch/processor/chunker/DelimiterChunkerTests.java index fec187b9e..37969a51b 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/chunker/DelimiterChunkerTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/chunker/DelimiterChunkerTests.java @@ -33,7 +33,7 @@ public void testChunkerWithDelimiterFieldNoString() { public void testChunker() { DelimiterChunker chunker = new DelimiterChunker(Map.of(DELIMITER_FIELD, "\n")); String content = "a\nb\nc\nd"; - List chunkResult = chunker.chunk(content); + List chunkResult = chunker.chunk(content, Map.of()); assertEquals(List.of("a\n", "b\n", "c\n", "d"), chunkResult); } @@ -41,42 +41,42 @@ public void testChunkerWithDefaultDelimiter() { // default delimiter is \n\n DelimiterChunker chunker = new DelimiterChunker(Map.of()); String content = "a.b\n\nc.d"; - List chunkResult = chunker.chunk(content); + List chunkResult = chunker.chunk(content, Map.of()); assertEquals(List.of("a.b\n\n", "c.d"), chunkResult); } public void testChunkerWithDelimiterEnd() { DelimiterChunker chunker = new DelimiterChunker(Map.of(DELIMITER_FIELD, "\n")); String content = "a\nb\nc\nd\n"; - List 
chunkResult = chunker.chunk(content); + List chunkResult = chunker.chunk(content, Map.of()); assertEquals(List.of("a\n", "b\n", "c\n", "d\n"), chunkResult); } public void testChunkerWithOnlyDelimiter() { DelimiterChunker chunker = new DelimiterChunker(Map.of(DELIMITER_FIELD, "\n")); String content = "\n"; - List chunkResult = chunker.chunk(content); + List chunkResult = chunker.chunk(content, Map.of()); assertEquals(List.of("\n"), chunkResult); } public void testChunkerWithAllDelimiters() { DelimiterChunker chunker = new DelimiterChunker(Map.of(DELIMITER_FIELD, "\n")); String content = "\n\n\n"; - List chunkResult = chunker.chunk(content); + List chunkResult = chunker.chunk(content, Map.of()); assertEquals(List.of("\n", "\n", "\n"), chunkResult); } public void testChunkerWithDifferentDelimiters() { DelimiterChunker chunker = new DelimiterChunker(Map.of(DELIMITER_FIELD, ".")); String content = "a.b.cc.d."; - List chunkResult = chunker.chunk(content); + List chunkResult = chunker.chunk(content, Map.of()); assertEquals(List.of("a.", "b.", "cc.", "d."), chunkResult); } public void testChunkerWithStringDelimiter() { DelimiterChunker chunker = new DelimiterChunker(Map.of(DELIMITER_FIELD, "\n\n")); String content = "\n\na\n\n\n"; - List chunkResult = chunker.chunk(content); + List chunkResult = chunker.chunk(content, Map.of()); assertEquals(List.of("\n\n", "a\n\n", "\n"), chunkResult); } } diff --git a/src/test/java/org/opensearch/neuralsearch/processor/chunker/FixedTokenLengthChunkerTests.java b/src/test/java/org/opensearch/neuralsearch/processor/chunker/FixedTokenLengthChunkerTests.java index d9934c184..49b633ced 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/chunker/FixedTokenLengthChunkerTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/chunker/FixedTokenLengthChunkerTests.java @@ -61,16 +61,16 @@ public Map> getTokeniz return new FixedTokenLengthChunker(nonRuntimeParameters); } - public void 
testValidateParameters_whenNoParams_thenSuccessful() { - fixedTokenLengthChunker.validateParameters(Map.of()); + public void testValidateAndParseParameters_whenNoParams_thenSuccessful() { + fixedTokenLengthChunker.validateAndParseParameters(Map.of()); } - public void testValidateParameters_whenIllegalTokenLimitType_thenFail() { + public void testValidateAndParseParameters_whenIllegalTokenLimitType_thenFail() { Map parameters = new HashMap<>(); parameters.put(TOKEN_LIMIT_FIELD, "invalid token limit"); IllegalArgumentException illegalArgumentException = assertThrows( IllegalArgumentException.class, - () -> fixedTokenLengthChunker.validateParameters(parameters) + () -> fixedTokenLengthChunker.validateAndParseParameters(parameters) ); assertEquals( "fixed length parameter [" + TOKEN_LIMIT_FIELD + "] cannot be cast to [" + Number.class.getName() + "]", @@ -78,22 +78,22 @@ public void testValidateParameters_whenIllegalTokenLimitType_thenFail() { ); } - public void testValidateParameters_whenIllegalTokenLimitValue_thenFail() { + public void testValidateAndParseParameters_whenIllegalTokenLimitValue_thenFail() { Map parameters = new HashMap<>(); parameters.put(TOKEN_LIMIT_FIELD, -1); IllegalArgumentException illegalArgumentException = assertThrows( IllegalArgumentException.class, - () -> fixedTokenLengthChunker.validateParameters(parameters) + () -> fixedTokenLengthChunker.validateAndParseParameters(parameters) ); assertEquals("fixed length parameter [" + TOKEN_LIMIT_FIELD + "] must be positive", illegalArgumentException.getMessage()); } - public void testValidateParameters_whenIllegalOverlapRateType_thenFail() { + public void testValidateAndParseParameters_whenIllegalOverlapRateType_thenFail() { Map parameters = new HashMap<>(); parameters.put(OVERLAP_RATE_FIELD, "invalid overlap rate"); IllegalArgumentException illegalArgumentException = assertThrows( IllegalArgumentException.class, - () -> fixedTokenLengthChunker.validateParameters(parameters) + () -> 
fixedTokenLengthChunker.validateAndParseParameters(parameters) ); assertEquals( "fixed length parameter [" + OVERLAP_RATE_FIELD + "] cannot be cast to [" + Number.class.getName() + "]", @@ -101,12 +101,12 @@ public void testValidateParameters_whenIllegalOverlapRateType_thenFail() { ); } - public void testValidateParameters_whenIllegalOverlapRateValue_thenFail() { + public void testValidateAndParseParameters_whenIllegalOverlapRateValue_thenFail() { Map parameters = new HashMap<>(); parameters.put(OVERLAP_RATE_FIELD, 0.6); IllegalArgumentException illegalArgumentException = assertThrows( IllegalArgumentException.class, - () -> fixedTokenLengthChunker.validateParameters(parameters) + () -> fixedTokenLengthChunker.validateAndParseParameters(parameters) ); assertEquals( "fixed length parameter [" + OVERLAP_RATE_FIELD + "] must be between 0 and 0.5", @@ -114,12 +114,12 @@ public void testValidateParameters_whenIllegalOverlapRateValue_thenFail() { ); } - public void testValidateParameters_whenIllegalTokenizerType_thenFail() { + public void testValidateAndParseParameters_whenIllegalTokenizerType_thenFail() { Map parameters = new HashMap<>(); parameters.put(TOKENIZER_FIELD, 111); IllegalArgumentException illegalArgumentException = assertThrows( IllegalArgumentException.class, - () -> fixedTokenLengthChunker.validateParameters(parameters) + () -> fixedTokenLengthChunker.validateAndParseParameters(parameters) ); assertEquals( "Chunker parameter [" + TOKENIZER_FIELD + "] cannot be cast to [" + String.class.getName() + "]",