From 0b27618e20e8728f5b90b86f3e743abe902e3cad Mon Sep 17 00:00:00 2001 From: yuye-aws Date: Thu, 7 Mar 2024 13:37:35 +0800 Subject: [PATCH] implement overlap rate with big decimal Signed-off-by: yuye-aws --- .../processor/DocumentChunkingProcessor.java | 1 - .../processor/chunker/FixedTokenLengthChunker.java | 13 ++++++++----- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/src/main/java/org/opensearch/neuralsearch/processor/DocumentChunkingProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/DocumentChunkingProcessor.java index ed269af20..292388f0c 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/DocumentChunkingProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/DocumentChunkingProcessor.java @@ -315,6 +315,5 @@ public DocumentChunkingProcessor create( analysisRegistry ); } - } } diff --git a/src/main/java/org/opensearch/neuralsearch/processor/chunker/FixedTokenLengthChunker.java b/src/main/java/org/opensearch/neuralsearch/processor/chunker/FixedTokenLengthChunker.java index 2f810d64d..98d9402f2 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/chunker/FixedTokenLengthChunker.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/chunker/FixedTokenLengthChunker.java @@ -5,9 +5,11 @@ package org.opensearch.neuralsearch.processor.chunker; import java.io.IOException; +import java.math.RoundingMode; import java.util.ArrayList; import java.util.List; import java.util.Map; +import java.math.BigDecimal; import lombok.extern.log4j.Log4j2; import org.opensearch.action.admin.indices.analyze.AnalyzeAction; @@ -57,7 +59,7 @@ private List tokenize(String content, String tokenizer, int maxTokenCoun public List chunk(String content, Map parameters) { // prior to chunking, parameters have been validated int tokenLimit = DEFAULT_TOKEN_LIMIT; - double overlapRate = DEFAULT_OVERLAP_RATE; + BigDecimal overlap_rate = new BigDecimal(String.valueOf(DEFAULT_OVERLAP_RATE)); int maxTokenCount = DEFAULT_MAX_TOKEN_COUNT; int maxChunkLimit = DEFAULT_MAX_CHUNK_LIMIT; @@ -67,7 +69,7 @@ public List chunk(String content, Map parameters) { tokenLimit = ((Number) parameters.get(TOKEN_LIMIT_FIELD)).intValue(); } if (parameters.containsKey(OVERLAP_RATE_FIELD)) { - overlapRate = ((Number) parameters.get(OVERLAP_RATE_FIELD)).doubleValue(); + overlap_rate = new BigDecimal(String.valueOf(parameters.get(OVERLAP_RATE_FIELD))); } if (parameters.containsKey(MAX_TOKEN_COUNT_FIELD)) { maxTokenCount = ((Number) parameters.get(MAX_TOKEN_COUNT_FIELD)).intValue(); @@ -84,7 +86,8 @@ public List chunk(String content, Map parameters) { String passage; int startToken = 0; - int overlapTokenNumber = (int) Math.floor(tokenLimit * overlapRate); + BigDecimal overlapTokenNumberBigDecimal = overlap_rate.multiply(new BigDecimal(String.valueOf(tokenLimit))).setScale(0, RoundingMode.DOWN); + int overlapTokenNumber = overlapTokenNumberBigDecimal.intValue();; // overlapTokenNumber must be smaller than the token limit overlapTokenNumber = Math.min(overlapTokenNumber, tokenLimit - 1); @@ -144,8 +147,8 @@ public void validateParameters(Map parameters) { "fixed length parameter [" + OVERLAP_RATE_FIELD + "] cannot be cast to [" + Number.class.getName() + "]" ); } - if (((Number) parameters.get(OVERLAP_RATE_FIELD)).doubleValue() < 0.0 - || ((Number) parameters.get(OVERLAP_RATE_FIELD)).doubleValue() >= 1.0) { + BigDecimal overlap_rate = new BigDecimal(String.valueOf(parameters.get(OVERLAP_RATE_FIELD))); + if (overlap_rate.compareTo(BigDecimal.ZERO) < 0 || overlap_rate.compareTo(BigDecimal.ONE) >= 0) { throw new IllegalArgumentException( "fixed length parameter [" + OVERLAP_RATE_FIELD + "] must be between 0 and 1, 1 is not included." );