Skip to content

Commit

Permalink
adjust functions in chunker interface
Browse files Browse the repository at this point in the history
Signed-off-by: yuye-aws <[email protected]>
  • Loading branch information
yuye-aws committed Mar 13, 2024
1 parent 3f85e63 commit 959c64e
Show file tree
Hide file tree
Showing 7 changed files with 37 additions and 48 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -148,12 +148,7 @@ private boolean isListOfString(Object value) {

private int chunkString(String content, List<String> result, Map<String, Object> runTimeParameters, int chunkCount) {
// chunk the content, return the updated chunkCount and add chunk passages to result
List<String> contentResult;
if (chunker instanceof FixedTokenLengthChunker) {
contentResult = chunker.chunk(content, runTimeParameters);
} else {
contentResult = chunker.chunk(content);
}
List<String> contentResult = chunker.chunk(content, runTimeParameters);
chunkCount += contentResult.size();
if (maxChunkLimit != DEFAULT_MAX_CHUNK_LIMIT && chunkCount > maxChunkLimit) {
throw new IllegalArgumentException(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@
*/
package org.opensearch.neuralsearch.processor.chunker;

import com.google.common.collect.ImmutableList;

import java.util.Map;
import java.util.List;

Expand All @@ -16,22 +14,12 @@
public interface Chunker {

/**
* Validate the parameters for chunking algorithm,
* Validate and parse the parameters for chunking algorithm,
* will throw IllegalArgumentException when parameters are invalid
*
* @param parameters a map containing parameters for chunking algorithms
*/
void validateParameters(Map<String, Object> parameters);

/**
* Chunk the incoming string according to parameters and return chunked passages
*
* @param content input string
* @return Chunked passages
*/
default List<String> chunk(String content) {
return ImmutableList.of();
}
void validateAndParseParameters(Map<String, Object> parameters);

/**
* Chunk the incoming string according to parameters and return chunked passages
Expand All @@ -40,7 +28,5 @@ default List<String> chunk(String content) {
* @param runtimeParameters a map containing runtime parameters for chunking algorithms
* @return Chunked passages
*/
default List<String> chunk(String content, Map<String, Object> runtimeParameters) {
return ImmutableList.of();
}
List<String> chunk(String content, Map<String, Object> runtimeParameters);
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@

import java.util.Map;

/**
* Validate and parse the parameters for chunking algorithms
*/
public class ChunkerParameterValidator {

public static String validateStringParameters(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
public class DelimiterChunker implements Chunker {

public DelimiterChunker(Map<String, Object> parameters) {
validateParameters(parameters);
validateAndParseParameters(parameters);
}

public static final String DELIMITER_FIELD = "delimiter";
Expand All @@ -33,12 +33,18 @@ public DelimiterChunker(Map<String, Object> parameters) {
* @throws IllegalArgumentException If delimiter is not a string or empty
*/
@Override
public void validateParameters(Map<String, Object> parameters) {
public void validateAndParseParameters(Map<String, Object> parameters) {
this.delimiter = validateStringParameters(parameters, DELIMITER_FIELD, DEFAULT_DELIMITER, false);
}

/**
* Return the chunked passages for the delimiter algorithm
*
* @param content input string
* @param runtimeParameters a map of runtime parameters; not needed by the delimiter algorithm
* @return Chunked passages
*/
@Override
public List<String> chunk(String content) {
public List<String> chunk(String content, Map<String, Object> runtimeParameters) {
List<String> chunkResult = new ArrayList<>();
int start = 0, end;
int nextDelimiterPosition = content.indexOf(delimiter);
Expand All @@ -56,5 +62,4 @@ public List<String> chunk(String content) {

return chunkResult;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ public class FixedTokenLengthChunker implements Chunker {
private final AnalysisRegistry analysisRegistry;

public FixedTokenLengthChunker(Map<String, Object> parameters) {
validateParameters(parameters);
validateAndParseParameters(parameters);
this.analysisRegistry = (AnalysisRegistry) parameters.get(ANALYSIS_REGISTRY_FIELD);
}

Expand All @@ -63,7 +63,7 @@ public FixedTokenLengthChunker(Map<String, Object> parameters) {
* tokenizer should be string
*/
@Override
public void validateParameters(Map<String, Object> parameters) {
public void validateAndParseParameters(Map<String, Object> parameters) {
this.tokenLimit = validatePositiveIntegerParameter(parameters, TOKEN_LIMIT_FIELD, DEFAULT_TOKEN_LIMIT);
if (parameters.containsKey(OVERLAP_RATE_FIELD)) {
String overlapRateString = parameters.get(OVERLAP_RATE_FIELD).toString();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,50 +33,50 @@ public void testChunkerWithDelimiterFieldNoString() {
public void testChunker() {
DelimiterChunker chunker = new DelimiterChunker(Map.of(DELIMITER_FIELD, "\n"));
String content = "a\nb\nc\nd";
List<String> chunkResult = chunker.chunk(content);
List<String> chunkResult = chunker.chunk(content, Map.of());
assertEquals(List.of("a\n", "b\n", "c\n", "d"), chunkResult);
}

public void testChunkerWithDefaultDelimiter() {
// default delimiter is \n\n
DelimiterChunker chunker = new DelimiterChunker(Map.of());
String content = "a.b\n\nc.d";
List<String> chunkResult = chunker.chunk(content);
List<String> chunkResult = chunker.chunk(content, Map.of());
assertEquals(List.of("a.b\n\n", "c.d"), chunkResult);
}

public void testChunkerWithDelimiterEnd() {
DelimiterChunker chunker = new DelimiterChunker(Map.of(DELIMITER_FIELD, "\n"));
String content = "a\nb\nc\nd\n";
List<String> chunkResult = chunker.chunk(content);
List<String> chunkResult = chunker.chunk(content, Map.of());
assertEquals(List.of("a\n", "b\n", "c\n", "d\n"), chunkResult);
}

public void testChunkerWithOnlyDelimiter() {
DelimiterChunker chunker = new DelimiterChunker(Map.of(DELIMITER_FIELD, "\n"));
String content = "\n";
List<String> chunkResult = chunker.chunk(content);
List<String> chunkResult = chunker.chunk(content, Map.of());
assertEquals(List.of("\n"), chunkResult);
}

public void testChunkerWithAllDelimiters() {
DelimiterChunker chunker = new DelimiterChunker(Map.of(DELIMITER_FIELD, "\n"));
String content = "\n\n\n";
List<String> chunkResult = chunker.chunk(content);
List<String> chunkResult = chunker.chunk(content, Map.of());
assertEquals(List.of("\n", "\n", "\n"), chunkResult);
}

public void testChunkerWithDifferentDelimiters() {
DelimiterChunker chunker = new DelimiterChunker(Map.of(DELIMITER_FIELD, "."));
String content = "a.b.cc.d.";
List<String> chunkResult = chunker.chunk(content);
List<String> chunkResult = chunker.chunk(content, Map.of());
assertEquals(List.of("a.", "b.", "cc.", "d."), chunkResult);
}

public void testChunkerWithStringDelimiter() {
DelimiterChunker chunker = new DelimiterChunker(Map.of(DELIMITER_FIELD, "\n\n"));
String content = "\n\na\n\n\n";
List<String> chunkResult = chunker.chunk(content);
List<String> chunkResult = chunker.chunk(content, Map.of());
assertEquals(List.of("\n\n", "a\n\n", "\n"), chunkResult);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -61,65 +61,65 @@ public Map<String, AnalysisModule.AnalysisProvider<TokenizerFactory>> getTokeniz
return new FixedTokenLengthChunker(nonRuntimeParameters);
}

public void testValidateParameters_whenNoParams_thenSuccessful() {
fixedTokenLengthChunker.validateParameters(Map.of());
public void testValidateAndParseParameters_whenNoParams_thenSuccessful() {
fixedTokenLengthChunker.validateAndParseParameters(Map.of());
}

public void testValidateParameters_whenIllegalTokenLimitType_thenFail() {
public void testValidateAndParseParameters_whenIllegalTokenLimitType_thenFail() {
Map<String, Object> parameters = new HashMap<>();
parameters.put(TOKEN_LIMIT_FIELD, "invalid token limit");
IllegalArgumentException illegalArgumentException = assertThrows(
IllegalArgumentException.class,
() -> fixedTokenLengthChunker.validateParameters(parameters)
() -> fixedTokenLengthChunker.validateAndParseParameters(parameters)
);
assertEquals(
"fixed length parameter [" + TOKEN_LIMIT_FIELD + "] cannot be cast to [" + Number.class.getName() + "]",
illegalArgumentException.getMessage()
);
}

public void testValidateParameters_whenIllegalTokenLimitValue_thenFail() {
public void testValidateAndParseParameters_whenIllegalTokenLimitValue_thenFail() {
Map<String, Object> parameters = new HashMap<>();
parameters.put(TOKEN_LIMIT_FIELD, -1);
IllegalArgumentException illegalArgumentException = assertThrows(
IllegalArgumentException.class,
() -> fixedTokenLengthChunker.validateParameters(parameters)
() -> fixedTokenLengthChunker.validateAndParseParameters(parameters)
);
assertEquals("fixed length parameter [" + TOKEN_LIMIT_FIELD + "] must be positive", illegalArgumentException.getMessage());
}

public void testValidateParameters_whenIllegalOverlapRateType_thenFail() {
public void testValidateAndParseParameters_whenIllegalOverlapRateType_thenFail() {
Map<String, Object> parameters = new HashMap<>();
parameters.put(OVERLAP_RATE_FIELD, "invalid overlap rate");
IllegalArgumentException illegalArgumentException = assertThrows(
IllegalArgumentException.class,
() -> fixedTokenLengthChunker.validateParameters(parameters)
() -> fixedTokenLengthChunker.validateAndParseParameters(parameters)
);
assertEquals(
"fixed length parameter [" + OVERLAP_RATE_FIELD + "] cannot be cast to [" + Number.class.getName() + "]",
illegalArgumentException.getMessage()
);
}

public void testValidateParameters_whenIllegalOverlapRateValue_thenFail() {
public void testValidateAndParseParameters_whenIllegalOverlapRateValue_thenFail() {
Map<String, Object> parameters = new HashMap<>();
parameters.put(OVERLAP_RATE_FIELD, 0.6);
IllegalArgumentException illegalArgumentException = assertThrows(
IllegalArgumentException.class,
() -> fixedTokenLengthChunker.validateParameters(parameters)
() -> fixedTokenLengthChunker.validateAndParseParameters(parameters)
);
assertEquals(
"fixed length parameter [" + OVERLAP_RATE_FIELD + "] must be between 0 and 0.5",
illegalArgumentException.getMessage()
);
}

public void testValidateParameters_whenIllegalTokenizerType_thenFail() {
public void testValidateAndParseParameters_whenIllegalTokenizerType_thenFail() {
Map<String, Object> parameters = new HashMap<>();
parameters.put(TOKENIZER_FIELD, 111);
IllegalArgumentException illegalArgumentException = assertThrows(
IllegalArgumentException.class,
() -> fixedTokenLengthChunker.validateParameters(parameters)
() -> fixedTokenLengthChunker.validateAndParseParameters(parameters)
);
assertEquals(
"Chunker parameter [" + TOKENIZER_FIELD + "] cannot be cast to [" + String.class.getName() + "]",
Expand Down

0 comments on commit 959c64e

Please sign in to comment.