Commit c2dbc85: update chunker interface

Signed-off-by: yuye-aws <[email protected]>
yuye-aws committed Mar 11, 2024 (1 parent: f0475e2)

Showing 4 changed files with 90 additions and 74 deletions.
Chunker.java (org.opensearch.neuralsearch.processor.chunker)

```diff
@@ -4,6 +4,9 @@
  */
 package org.opensearch.neuralsearch.processor.chunker;
 
+import com.google.common.collect.ImmutableList;
+
+import java.util.ArrayList;
 import java.util.Map;
 import java.util.List;
 
@@ -25,8 +28,20 @@ public interface Chunker {
      * Chunk the incoming string according to parameters and return chunked passages
      *
      * @param content input string
-     * @param parameters a map containing parameters for chunking algorithms
      * @return Chunked passages
      */
-    List<String> chunk(String content, Map<String, Object> parameters);
+    default List<String> chunk(String content) {
+        return ImmutableList.of();
+    }
+
+    /**
+     * Chunk the incoming string according to parameters and return chunked passages
+     *
+     * @param content input string
+     * @param runtimeParameters a map containing runtime parameters for chunking algorithms
+     * @return Chunked passages
+     */
+    default List<String> chunk(String content, Map<String, Object> runtimeParameters) {
+        return ImmutableList.of();
+    }
 }
```
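The reworked interface splits configuration from invocation: fixed settings are handled once at construction, while the two new default `chunk` overloads cover calls with and without per-request ("runtime") parameters. Below is a minimal sketch of an implementer. `WhitespaceChunker` is hypothetical and not part of the commit, and the `validateParameters` declaration on the interface is inferred from the `@Override` annotations in the concrete chunkers further down:

```java
import java.util.Arrays;
import java.util.List;
import java.util.Map;

public class WhitespaceChunker implements Chunker {

    // configuration arrives once, at construction, and is validated immediately
    public WhitespaceChunker(Map<String, Object> parameters) {
        validateParameters(parameters);
    }

    @Override
    public void validateParameters(Map<String, Object> parameters) {
        // nothing to validate: this toy chunker takes no parameters
    }

    @Override
    public List<String> chunk(String content) {
        // split on runs of whitespace; no runtime parameters are needed
        return Arrays.asList(content.split("\\s+"));
    }
}
```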
ChunkerFactory.java (org.opensearch.neuralsearch.processor.chunker)

```diff
@@ -4,6 +4,7 @@
  */
 package org.opensearch.neuralsearch.processor.chunker;
 
+import java.util.Map;
 import java.util.Set;
 
 import org.opensearch.index.analysis.AnalysisRegistry;
@@ -16,12 +17,12 @@ public class ChunkerFactory {
     public static final String FIXED_TOKEN_LENGTH_ALGORITHM = "fixed_token_length";
     public static final String DELIMITER_ALGORITHM = "delimiter";
 
-    public static Chunker create(String type, AnalysisRegistry analysisRegistry) {
+    public static Chunker create(String type, Map<String, Object> parameters) {
         switch (type) {
             case FIXED_TOKEN_LENGTH_ALGORITHM:
-                return new FixedTokenLengthChunker(analysisRegistry);
+                return new FixedTokenLengthChunker(parameters);
             case DELIMITER_ALGORITHM:
-                return new DelimiterChunker();
+                return new DelimiterChunker(parameters);
             default:
                 throw new IllegalArgumentException(
                     "chunker type [" + type + "] is not supported. Supported chunkers types are " + ChunkerFactory.getAllChunkers()
```
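A usage sketch (not from the commit) of what the new factory signature enables: callers hand over the whole parameters map, and each chunker validates its own settings in its constructor. For the fixed-token-length chunker, the AnalysisRegistry presumably travels in the same map under the new "analysis_registry" key added below; that wiring is outside the shown hunks.

```java
Map<String, Object> parameters = Map.of(DelimiterChunker.DELIMITER_FIELD, "\n");
Chunker chunker = ChunkerFactory.create(ChunkerFactory.DELIMITER_ALGORITHM, parameters);
// invalid settings now fail fast at creation time, e.g. a non-String delimiter
// throws IllegalArgumentException inside the DelimiterChunker constructor
List<String> passages = chunker.chunk("first line\nsecond line");
```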
DelimiterChunker.java (org.opensearch.neuralsearch.processor.chunker)

```diff
@@ -15,12 +15,16 @@
  */
 public class DelimiterChunker implements Chunker {
 
-    public DelimiterChunker() {}
+    public DelimiterChunker(Map<String, Object> parameters) {
+        validateParameters(parameters);
+    }
 
     public static final String DELIMITER_FIELD = "delimiter";
 
     public static final String DEFAULT_DELIMITER = ".";
 
+    private String delimiter = DEFAULT_DELIMITER;
+
     /**
      * Validate the chunked passages for delimiter algorithm
      *
@@ -31,11 +35,14 @@ public DelimiterChunker() {}
     @Override
     public void validateParameters(Map<String, Object> parameters) {
         if (parameters.containsKey(DELIMITER_FIELD)) {
-            Object delimiter = parameters.get(DELIMITER_FIELD);
-            if (!(delimiter instanceof String)) {
-                throw new IllegalArgumentException("delimiter parameters: " + delimiter + " must be string.");
-            } else if (StringUtils.isBlank(delimiter.toString())) {
-                throw new IllegalArgumentException("delimiter parameters should not be empty.");
+            if (!(parameters.get(DELIMITER_FIELD) instanceof String)) {
+                throw new IllegalArgumentException(
+                    "delimiter parameter [" + DELIMITER_FIELD + "] cannot be cast to [" + String.class.getName() + "]"
+                );
+            }
+            this.delimiter = parameters.get(DELIMITER_FIELD).toString();
+            if (StringUtils.isBlank(delimiter)) {
+                throw new IllegalArgumentException("delimiter parameter [" + DELIMITER_FIELD + "] should not be empty.");
             }
         }
     }
@@ -44,11 +51,9 @@ public void validateParameters(Map<String, Object> parameters) {
      * Return the chunked passages for delimiter algorithm
      *
      * @param content input string
-     * @param parameters a map containing parameters, containing the following parameters
      */
     @Override
-    public List<String> chunk(String content, Map<String, Object> parameters) {
-        String delimiter = DEFAULT_DELIMITER;
+    public List<String> chunk(String content) {
         if (parameters.containsKey(DELIMITER_FIELD)) {
             Object delimiterObject = parameters.get(DELIMITER_FIELD);
             delimiter = delimiterObject.toString();
```
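An illustrative behavior check. Only the validation and the new `chunk(content)` signature are visible in these hunks, so the expected output assumes the splitting loop below the fold keeps each delimiter at the end of its passage:

```java
DelimiterChunker chunker = new DelimiterChunker(Map.of(DelimiterChunker.DELIMITER_FIELD, "."));
// the delimiter is now fixed at construction instead of being read on every call
List<String> passages = chunker.chunk("a. b. c.");
// assumed result, if passages retain their trailing delimiter: ["a.", " b.", " c."]

// blank and non-String delimiters are rejected up front:
// new DelimiterChunker(Map.of(DelimiterChunker.DELIMITER_FIELD, " "))  -> IllegalArgumentException
```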
FixedTokenLengthChunker.java (org.opensearch.neuralsearch.processor.chunker)

```diff
@@ -10,12 +10,20 @@
 import java.util.Map;
 import java.util.List;
 import java.util.ArrayList;
 
+import lombok.Setter;
+import lombok.extern.log4j.Log4j2;
 import org.apache.commons.lang3.math.NumberUtils;
 
+import org.opensearch.cluster.metadata.IndexMetadata;
+import org.opensearch.index.IndexService;
+import org.opensearch.index.IndexSettings;
 import org.opensearch.index.analysis.AnalysisRegistry;
 import org.opensearch.action.admin.indices.analyze.AnalyzeAction;
 import org.opensearch.action.admin.indices.analyze.AnalyzeAction.AnalyzeToken;
+import org.opensearch.index.mapper.IndexFieldMapper;
+import org.opensearch.ingest.IngestDocument;
 
 import static org.opensearch.action.admin.indices.analyze.TransportAnalyzeAction.analyze;
 
 /**
@@ -27,40 +35,66 @@ public class FixedTokenLengthChunker implements Chunker {
     public static final String TOKEN_LIMIT_FIELD = "token_limit";
     public static final String OVERLAP_RATE_FIELD = "overlap_rate";
     public static final String MAX_TOKEN_COUNT_FIELD = "max_token_count";
+    public static final String ANALYSIS_REGISTRY_FIELD = "analysis_registry";
     public static final String TOKENIZER_FIELD = "tokenizer";
 
     // default values for each parameter
-    private static final int DEFAULT_TOKEN_LIMIT = 500;
+    private static final int DEFAULT_TOKEN_LIMIT = 384;
     private static final BigDecimal DEFAULT_OVERLAP_RATE = new BigDecimal("0");
     private static final int DEFAULT_MAX_TOKEN_COUNT = 10000;
     private static final String DEFAULT_TOKENIZER = "standard";
 
     private static final BigDecimal OVERLAP_RATE_UPPER_BOUND = new BigDecimal("0.5");
 
-    private final AnalysisRegistry analysisRegistry;
+    private int tokenLimit = DEFAULT_TOKEN_LIMIT;
+    private BigDecimal overlapRate = DEFAULT_OVERLAP_RATE;
+    private String tokenizer = DEFAULT_TOKENIZER;
+    private AnalysisRegistry analysisRegistry;
 
-    public FixedTokenLengthChunker(AnalysisRegistry analysisRegistry) {
-        this.analysisRegistry = analysisRegistry;
+    public FixedTokenLengthChunker(Map<String, Object> parameters) {
+        validateParameters(parameters);
     }
 
     /**
-     * Validate the chunked passages for fixed token length algorithm,
+     * Validate and parse the parameters for fixed token length algorithm,
      * will throw IllegalArgumentException when parameters are invalid
      *
      * @param parameters a map containing parameters, containing the following parameters:
-     * 1. tokenizer the analyzer tokenizer in opensearch, please check https://opensearch.org/docs/latest/analyzers/tokenizers/index/
-     * 2. token_limit the token limit for each chunked passage
-     * 3. overlap_rate the overlapping degree for each chunked passage, indicating how many token comes from the previous passage
-     * 4. max_token_count the max token limit for the tokenizer
+     * 1. tokenizer: the <a href="https://opensearch.org/docs/latest/analyzers/tokenizers/index/">analyzer tokenizer</a> in opensearch
+     * 2. token_limit: the token limit for each chunked passage
+     * 3. overlap_rate: the overlapping degree for each chunked passage, indicating how many token comes from the previous passage
+     * 4. max_token_count: the max token limit for the tokenizer
+     * Here are requirements for parameters:
+     * max_token_count and token_limit should be a positive integer
+     * overlap_rate should be within range [0, 0.5]
+     * tokenizer should be string
      */
     @Override
     public void validateParameters(Map<String, Object> parameters) {
-        validatePositiveIntegerParameter(parameters, TOKEN_LIMIT_FIELD);
-        validatePositiveIntegerParameter(parameters, MAX_TOKEN_COUNT_FIELD);
+        if (parameters.containsKey(TOKEN_LIMIT_FIELD)) {
+            String tokenLimitString = parameters.get(TOKEN_LIMIT_FIELD).toString();
+            if (!(NumberUtils.isParsable(tokenLimitString))) {
+                throw new IllegalArgumentException(
+                    "fixed length parameter [" + TOKEN_LIMIT_FIELD + "] cannot be cast to [" + Number.class.getName() + "]"
+                );
+            }
+            this.tokenLimit = NumberUtils.createInteger(tokenLimitString);
+            if (tokenLimit <= 0) {
+                throw new IllegalArgumentException("fixed length parameter [" + TOKEN_LIMIT_FIELD + "] must be positive");
+            }
+        }
+
+        if (parameters.containsKey(MAX_TOKEN_COUNT_FIELD)) {
+            String maxTokenCountString = parameters.get(MAX_TOKEN_COUNT_FIELD).toString();
+            if (!(NumberUtils.isParsable(maxTokenCountString))) {
+                throw new IllegalArgumentException(
+                    "fixed length parameter [" + MAX_TOKEN_COUNT_FIELD + "] cannot be cast to [" + Number.class.getName() + "]"
+                );
+            }
+            this.maxTokenCount = NumberUtils.createInteger(maxTokenCountString);
+            if (maxTokenCount <= 0) {
+                throw new IllegalArgumentException("fixed length parameter [" + MAX_TOKEN_COUNT_FIELD + "] must be positive");
+            }
+        }
 
         if (parameters.containsKey(OVERLAP_RATE_FIELD)) {
             String overlapRateString = parameters.get(OVERLAP_RATE_FIELD).toString();
@@ -69,70 +103,31 @@ public void validateParameters(Map<String, Object> parameters) {
                     "fixed length parameter [" + OVERLAP_RATE_FIELD + "] cannot be cast to [" + Number.class.getName() + "]"
                 );
             }
-            BigDecimal overlapRate = new BigDecimal(overlapRateString);
+            this.overlapRate = new BigDecimal(overlapRateString);
             if (overlapRate.compareTo(BigDecimal.ZERO) < 0 || overlapRate.compareTo(OVERLAP_RATE_UPPER_BOUND) > 0) {
                 throw new IllegalArgumentException(
                     "fixed length parameter [" + OVERLAP_RATE_FIELD + "] must be between 0 and " + OVERLAP_RATE_UPPER_BOUND
                 );
             }
         }
 
-        if (parameters.containsKey(TOKENIZER_FIELD) && !(parameters.get(TOKENIZER_FIELD) instanceof String)) {
-            throw new IllegalArgumentException(
-                "fixed length parameter [" + TOKENIZER_FIELD + "] cannot be cast to [" + String.class.getName() + "]"
-            );
-        }
-    }
-
-    private void validatePositiveIntegerParameter(Map<String, Object> parameters, String fieldName) {
-        // this method validate that parameter is a positive integer
-        if (!parameters.containsKey(fieldName)) {
-            // all parameters are optional
-            return;
-        }
-        String fieldValue = parameters.get(fieldName).toString();
-        if (!(NumberUtils.isParsable(fieldValue))) {
-            throw new IllegalArgumentException(
-                "fixed length parameter [" + fieldName + "] cannot be cast to [" + Number.class.getName() + "]"
-            );
-        }
-        if (NumberUtils.createInteger(fieldValue) <= 0) {
-            throw new IllegalArgumentException("fixed length parameter [" + fieldName + "] must be positive");
+        if (parameters.containsKey(TOKENIZER_FIELD)) {
+            if (!(parameters.get(TOKENIZER_FIELD) instanceof String)) {
+                throw new IllegalArgumentException(
+                    "fixed length parameter [" + TOKENIZER_FIELD + "] cannot be cast to [" + String.class.getName() + "]"
+                );
+            }
+            this.tokenizer = parameters.get(TOKENIZER_FIELD).toString();
         }
     }
 
     /**
      * Return the chunked passages for fixed token length algorithm
      *
      * @param content input string
-     * @param parameters a map containing parameters, containing the following parameters
-     * 1. tokenizer the <a href="https://opensearch.org/docs/latest/analyzers/tokenizers/index/">analyzer tokenizer</a> in OpenSearch
-     * 2. token_limit the token limit for each chunked passage
-     * 3. overlap_rate the overlapping degree for each chunked passage, indicating how many token comes from the previous passage
-     * 4. max_token_count the max token limit for the tokenizer
      */
     @Override
-    public List<String> chunk(String content, Map<String, Object> parameters) {
-        // prior to chunking, parameters have been validated
-        int tokenLimit = DEFAULT_TOKEN_LIMIT;
-        BigDecimal overlapRate = DEFAULT_OVERLAP_RATE;
-        int maxTokenCount = DEFAULT_MAX_TOKEN_COUNT;
-
-        String tokenizer = DEFAULT_TOKENIZER;
-
-        if (parameters.containsKey(TOKEN_LIMIT_FIELD)) {
-            tokenLimit = ((Number) parameters.get(TOKEN_LIMIT_FIELD)).intValue();
-        }
-        if (parameters.containsKey(OVERLAP_RATE_FIELD)) {
-            overlapRate = new BigDecimal(parameters.get(OVERLAP_RATE_FIELD).toString());
-        }
-        if (parameters.containsKey(MAX_TOKEN_COUNT_FIELD)) {
-            maxTokenCount = ((Number) parameters.get(MAX_TOKEN_COUNT_FIELD)).intValue();
-        }
-        if (parameters.containsKey(TOKENIZER_FIELD)) {
-            tokenizer = (String) parameters.get(TOKENIZER_FIELD);
-        }
-
+    public List<String> chunk(String content, Map<String, Object> runtimeParameters) {
         List<AnalyzeToken> tokens = tokenize(content, tokenizer, maxTokenCount);
         List<String> passages = new ArrayList<>();
 
```
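A construction sketch for the new parameter handling: all four documented settings are optional, falling back to the defaults above (token_limit 384, overlap_rate 0, tokenizer "standard", max_token_count 10000). `longDocument` is a placeholder, passing an empty runtime map is assumed to fall back to the construction-time values, and in the processor the AnalysisRegistry would presumably be supplied under ANALYSIS_REGISTRY_FIELD, omitted here:

```java
String longDocument = "...";  // placeholder for the text to be chunked

Map<String, Object> parameters = Map.of(
    FixedTokenLengthChunker.TOKEN_LIMIT_FIELD, 256,
    FixedTokenLengthChunker.OVERLAP_RATE_FIELD, "0.25",
    FixedTokenLengthChunker.TOKENIZER_FIELD, "standard"
);
// validation now runs in the constructor: token_limit must be a positive
// integer, overlap_rate must lie in [0, 0.5], tokenizer must be a String
FixedTokenLengthChunker chunker = new FixedTokenLengthChunker(parameters);

// per the javadoc, overlap_rate 0.25 with token_limit 256 would mean each
// passage shares 0.25 * 256 = 64 tokens with its predecessor
List<String> passages = chunker.chunk(longDocument, Map.of());
```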
