Skip to content

Commit

Permalink
move algorithm name definitions to the concrete chunker classes
Browse files Browse the repository at this point in the history
Signed-off-by: yuye-aws <[email protected]>
  • Loading branch information
yuye-aws committed Mar 13, 2024
1 parent 959c64e commit 906cf73
Show file tree
Hide file tree
Showing 7 changed files with 23 additions and 28 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,6 @@
import org.opensearch.index.mapper.IndexFieldMapper;
import org.opensearch.neuralsearch.processor.chunker.FixedTokenLengthChunker;

import static org.opensearch.neuralsearch.processor.chunker.ChunkerFactory.FIXED_TOKEN_LENGTH_ALGORITHM;

/**
* This processor is used for chunking user input data and chunked data could be used for downstream embedding processor,
* algorithm can be used to indicate chunking algorithm and parameters,
Expand Down Expand Up @@ -111,7 +109,7 @@ private void validateAndParseAlgorithmMap(Map<String, Object> algorithmMap) {
);
}
Map<String, Object> chunkerParameters = (Map<String, Object>) algorithmValue;
if (Objects.equals(algorithmKey, FIXED_TOKEN_LENGTH_ALGORITHM)) {
if (Objects.equals(algorithmKey, FixedTokenLengthChunker.ALGORITHM_NAME)) {
chunkerParameters.put(FixedTokenLengthChunker.ANALYSIS_REGISTRY_FIELD, analysisRegistry);
}
this.chunker = ChunkerFactory.create(algorithmKey, chunkerParameters);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,11 @@
*/
public class ChunkerFactory {

public static final String FIXED_TOKEN_LENGTH_ALGORITHM = "fixed_token_length";
public static final String DELIMITER_ALGORITHM = "delimiter";

public static Chunker create(String type, Map<String, Object> parameters) {
switch (type) {
case FIXED_TOKEN_LENGTH_ALGORITHM:
case FixedTokenLengthChunker.ALGORITHM_NAME:
return new FixedTokenLengthChunker(parameters);
case DELIMITER_ALGORITHM:
case DelimiterChunker.ALGORITHM_NAME:
return new DelimiterChunker(parameters);
default:
throw new IllegalArgumentException(
Expand All @@ -29,6 +26,6 @@ public static Chunker create(String type, Map<String, Object> parameters) {
}

public static Set<String> getAllChunkers() {
return Set.of(FIXED_TOKEN_LENGTH_ALGORITHM, DELIMITER_ALGORITHM);
return Set.of(FixedTokenLengthChunker.ALGORITHM_NAME, DelimiterChunker.ALGORITHM_NAME);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,16 @@
*/
public class DelimiterChunker implements Chunker {

public DelimiterChunker(Map<String, Object> parameters) {
validateAndParseParameters(parameters);
}

public static final String ALGORITHM_NAME = "delimiter";
public static final String DELIMITER_FIELD = "delimiter";

public static final String DEFAULT_DELIMITER = "\n\n";

private String delimiter;

public DelimiterChunker(Map<String, Object> parameters) {
validateAndParseParameters(parameters);
}

/**
* Validate the chunked passages for delimiter algorithm
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
*/
public class FixedTokenLengthChunker implements Chunker {

public static final String ALGORITHM_NAME = "fixed_token_length";
public static final String ANALYSIS_REGISTRY_FIELD = "analysis_registry";
public static final String TOKEN_LIMIT_FIELD = "token_limit";
public static final String OVERLAP_RATE_FIELD = "overlap_rate";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@
import org.opensearch.indices.analysis.AnalysisModule;
import org.opensearch.ingest.IngestDocument;
import org.opensearch.ingest.Processor;
import org.opensearch.neuralsearch.processor.chunker.ChunkerFactory;
import org.opensearch.neuralsearch.processor.chunker.DelimiterChunker;
import org.opensearch.neuralsearch.processor.chunker.FixedTokenLengthChunker;
import org.opensearch.neuralsearch.processor.factory.TextChunkingProcessorFactory;
Expand Down Expand Up @@ -128,7 +127,7 @@ private Map<String, Object> createNestedFieldMap() {
private TextChunkingProcessor createFixedTokenLengthInstance(Map<String, Object> fieldMap) {
Map<String, Object> config = new HashMap<>();
Map<String, Object> algorithmMap = new HashMap<>();
algorithmMap.put(ChunkerFactory.FIXED_TOKEN_LENGTH_ALGORITHM, createFixedTokenLengthParameters());
algorithmMap.put(FixedTokenLengthChunker.ALGORITHM_NAME, createFixedTokenLengthParameters());
config.put(FIELD_MAP_FIELD, fieldMap);
config.put(ALGORITHM_FIELD, algorithmMap);
Map<String, Processor.Factory> registry = new HashMap<>();
Expand All @@ -139,7 +138,7 @@ private TextChunkingProcessor createFixedTokenLengthInstance(Map<String, Object>
private TextChunkingProcessor createFixedTokenLengthInstanceWithMaxChunkNum(Map<String, Object> fieldMap, int maxChunkNum) {
Map<String, Object> config = new HashMap<>();
Map<String, Object> algorithmMap = new HashMap<>();
algorithmMap.put(ChunkerFactory.FIXED_TOKEN_LENGTH_ALGORITHM, createFixedTokenLengthParametersWithMaxChunk(maxChunkNum));
algorithmMap.put(FixedTokenLengthChunker.ALGORITHM_NAME, createFixedTokenLengthParametersWithMaxChunk(maxChunkNum));
config.put(FIELD_MAP_FIELD, fieldMap);
config.put(ALGORITHM_FIELD, algorithmMap);
Map<String, Processor.Factory> registry = new HashMap<>();
Expand All @@ -151,7 +150,7 @@ private TextChunkingProcessor createDelimiterInstance() {
Map<String, Object> config = new HashMap<>();
Map<String, Object> fieldMap = new HashMap<>();
Map<String, Object> algorithmMap = new HashMap<>();
algorithmMap.put(ChunkerFactory.DELIMITER_ALGORITHM, createDelimiterParameters());
algorithmMap.put(DelimiterChunker.ALGORITHM_NAME, createDelimiterParameters());
fieldMap.put(INPUT_FIELD, OUTPUT_FIELD);
config.put(FIELD_MAP_FIELD, fieldMap);
config.put(ALGORITHM_FIELD, algorithmMap);
Expand All @@ -178,7 +177,7 @@ public void testCreate_whenMaxChunkNumInvalidValue_thenFail() {
Map<String, Object> fieldMap = new HashMap<>();
Map<String, Object> algorithmMap = new HashMap<>();
fieldMap.put(INPUT_FIELD, OUTPUT_FIELD);
algorithmMap.put(ChunkerFactory.FIXED_TOKEN_LENGTH_ALGORITHM, createFixedTokenLengthParametersWithMaxChunk(-2));
algorithmMap.put(FixedTokenLengthChunker.ALGORITHM_NAME, createFixedTokenLengthParametersWithMaxChunk(-2));
config.put(FIELD_MAP_FIELD, fieldMap);
config.put(ALGORITHM_FIELD, algorithmMap);
IllegalArgumentException illegalArgumentException = assertThrows(
Expand Down Expand Up @@ -213,8 +212,8 @@ public void testCreate_whenAlgorithmFieldMultipleAlgorithm_thenFail() {
Map<String, Object> algorithmMap = new HashMap<>();
fieldMap.put(INPUT_FIELD, OUTPUT_FIELD);
config.put(TextChunkingProcessor.FIELD_MAP_FIELD, fieldMap);
algorithmMap.put(ChunkerFactory.FIXED_TOKEN_LENGTH_ALGORITHM, createFixedTokenLengthParameters());
algorithmMap.put(ChunkerFactory.DELIMITER_ALGORITHM, createDelimiterParameters());
algorithmMap.put(FixedTokenLengthChunker.ALGORITHM_NAME, createFixedTokenLengthParameters());
algorithmMap.put(DelimiterChunker.ALGORITHM_NAME, createDelimiterParameters());
config.put(ALGORITHM_FIELD, algorithmMap);
Map<String, Processor.Factory> registry = new HashMap<>();
IllegalArgumentException illegalArgumentException = assertThrows(
Expand Down Expand Up @@ -251,7 +250,7 @@ public void testCreate_whenAlgorithmFieldInvalidAlgorithmContent_thenFail() {
Map<String, Object> algorithmMap = new HashMap<>();
fieldMap.put(INPUT_FIELD, OUTPUT_FIELD);
config.put(TextChunkingProcessor.FIELD_MAP_FIELD, fieldMap);
algorithmMap.put(ChunkerFactory.FIXED_TOKEN_LENGTH_ALGORITHM, 1);
algorithmMap.put(FixedTokenLengthChunker.ALGORITHM_NAME, 1);
config.put(ALGORITHM_FIELD, algorithmMap);
Map<String, Processor.Factory> registry = new HashMap<>();
IllegalArgumentException illegalArgumentException = assertThrows(
Expand All @@ -260,7 +259,7 @@ public void testCreate_whenAlgorithmFieldInvalidAlgorithmContent_thenFail() {
);
assertEquals(
"Unable to create the processor as ["
+ ChunkerFactory.FIXED_TOKEN_LENGTH_ALGORITHM
+ FixedTokenLengthChunker.ALGORITHM_NAME
+ "] parameters cannot be cast to ["
+ Map.class.getName()
+ "]",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,18 +20,18 @@ public class ChunkerFactoryTests extends OpenSearchTestCase {
private AnalysisRegistry analysisRegistry;

public void testGetAllChunkers() {
Set<String> expected = Set.of(ChunkerFactory.FIXED_TOKEN_LENGTH_ALGORITHM, ChunkerFactory.DELIMITER_ALGORITHM);
Set<String> expected = Set.of(FixedTokenLengthChunker.ALGORITHM_NAME, DelimiterChunker.ALGORITHM_NAME);
assertEquals(expected, ChunkerFactory.getAllChunkers());
}

public void testCreate_FixedTokenLength() {
Chunker chunker = ChunkerFactory.create(ChunkerFactory.FIXED_TOKEN_LENGTH_ALGORITHM, createChunkParameters());
Chunker chunker = ChunkerFactory.create(FixedTokenLengthChunker.ALGORITHM_NAME, createChunkParameters());
assertNotNull(chunker);
assertTrue(chunker instanceof FixedTokenLengthChunker);
}

public void testCreate_Delimiter() {
Chunker chunker = ChunkerFactory.create(ChunkerFactory.DELIMITER_ALGORITHM, createChunkParameters());
Chunker chunker = ChunkerFactory.create(DelimiterChunker.ALGORITHM_NAME, createChunkParameters());
assertNotNull(chunker);
assertTrue(chunker instanceof DelimiterChunker);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
import org.opensearch.indices.analysis.AnalysisModule;
import org.opensearch.ingest.Processor;
import org.opensearch.neuralsearch.processor.TextChunkingProcessor;
import org.opensearch.neuralsearch.processor.chunker.ChunkerFactory;
import org.opensearch.neuralsearch.processor.chunker.FixedTokenLengthChunker;
import org.opensearch.plugins.AnalysisPlugin;
import org.opensearch.test.OpenSearchTestCase;
import static org.opensearch.neuralsearch.processor.TextChunkingProcessor.TYPE;
Expand All @@ -34,7 +34,7 @@ public class TextChunkingProcessorFactoryTests extends OpenSearchTestCase {

private static final String PROCESSOR_TAG = "mockTag";
private static final String DESCRIPTION = "mockDescription";
private static final Map<String, Object> algorithmMap = Map.of(ChunkerFactory.FIXED_TOKEN_LENGTH_ALGORITHM, new HashMap<>());
private static final Map<String, Object> algorithmMap = Map.of(FixedTokenLengthChunker.ALGORITHM_NAME, new HashMap<>());

private TextChunkingProcessorFactory textChunkingProcessorFactory;

Expand Down

0 comments on commit 906cf73

Please sign in to comment.