Skip to content

Commit

Permalink
implement overlap rate with big decimal
Browse files Browse the repository at this point in the history
Signed-off-by: yuye-aws <[email protected]>
  • Loading branch information
yuye-aws committed Mar 7, 2024
1 parent 8db712b commit 0b27618
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -315,6 +315,5 @@ public DocumentChunkingProcessor create(
analysisRegistry
);
}

}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,11 @@
package org.opensearch.neuralsearch.processor.chunker;

import java.io.IOException;
import java.math.RoundingMode;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.math.BigDecimal;

import lombok.extern.log4j.Log4j2;
import org.opensearch.action.admin.indices.analyze.AnalyzeAction;
Expand Down Expand Up @@ -57,7 +59,7 @@ private List<String> tokenize(String content, String tokenizer, int maxTokenCoun
public List<String> chunk(String content, Map<String, Object> parameters) {
// prior to chunking, parameters have been validated
int tokenLimit = DEFAULT_TOKEN_LIMIT;
double overlapRate = DEFAULT_OVERLAP_RATE;
BigDecimal overlap_rate = new BigDecimal(String.valueOf(DEFAULT_OVERLAP_RATE));
int maxTokenCount = DEFAULT_MAX_TOKEN_COUNT;
int maxChunkLimit = DEFAULT_MAX_CHUNK_LIMIT;

Expand All @@ -67,7 +69,7 @@ public List<String> chunk(String content, Map<String, Object> parameters) {
tokenLimit = ((Number) parameters.get(TOKEN_LIMIT_FIELD)).intValue();
}
if (parameters.containsKey(OVERLAP_RATE_FIELD)) {
overlapRate = ((Number) parameters.get(OVERLAP_RATE_FIELD)).doubleValue();
overlap_rate = new BigDecimal(String.valueOf(parameters.get(OVERLAP_RATE_FIELD)));
}
if (parameters.containsKey(MAX_TOKEN_COUNT_FIELD)) {
maxTokenCount = ((Number) parameters.get(MAX_TOKEN_COUNT_FIELD)).intValue();
Expand All @@ -84,7 +86,8 @@ public List<String> chunk(String content, Map<String, Object> parameters) {

String passage;
int startToken = 0;
int overlapTokenNumber = (int) Math.floor(tokenLimit * overlapRate);
BigDecimal overlapTokenNumberBigDecimal = overlap_rate.multiply(new BigDecimal(String.valueOf(tokenLimit))).setScale(0, RoundingMode.DOWN);
int overlapTokenNumber = overlapTokenNumberBigDecimal.intValue();;
// overlapTokenNumber must be smaller than the token limit
overlapTokenNumber = Math.min(overlapTokenNumber, tokenLimit - 1);

Expand Down Expand Up @@ -144,8 +147,8 @@ public void validateParameters(Map<String, Object> parameters) {
"fixed length parameter [" + OVERLAP_RATE_FIELD + "] cannot be cast to [" + Number.class.getName() + "]"
);
}
if (((Number) parameters.get(OVERLAP_RATE_FIELD)).doubleValue() < 0.0
|| ((Number) parameters.get(OVERLAP_RATE_FIELD)).doubleValue() >= 1.0) {
BigDecimal overlap_rate = new BigDecimal(String.valueOf(parameters.get(OVERLAP_RATE_FIELD)));
if (overlap_rate.compareTo(BigDecimal.ZERO) < 0 || overlap_rate.compareTo(BigDecimal.ONE) >= 0) {
throw new IllegalArgumentException(
"fixed length parameter [" + OVERLAP_RATE_FIELD + "] must be between 0 and 1, 1 is not included."
);
Expand Down

0 comments on commit 0b27618

Please sign in to comment.