
Commit

fix UTs
Signed-off-by: xinyual <[email protected]>
xinyual committed Mar 11, 2024
1 parent 75badf7 commit eb4f36b
Showing 4 changed files with 33 additions and 28 deletions.
ChunkerFactory.java

@@ -7,10 +7,6 @@
 import java.util.Map;
 import java.util.Set;
 
-import org.opensearch.index.analysis.AnalysisRegistry;
-
-import static org.opensearch.neuralsearch.processor.chunker.FixedTokenLengthChunker.ANALYSIS_REGISTRY_FIELD;
-
 /**
  * A factory to create different chunking algorithm classes and return all supported chunking algorithms.
  */
@@ -22,7 +18,7 @@ public class ChunkerFactory {
     public static Chunker create(String type, Map<String, Object> parameters) {
         switch (type) {
             case FIXED_TOKEN_LENGTH_ALGORITHM:
-                return new FixedTokenLengthChunker((AnalysisRegistry) parameters.get(ANALYSIS_REGISTRY_FIELD), parameters);
+                return new FixedTokenLengthChunker(parameters);
             case DELIMITER_ALGORITHM:
                 return new DelimiterChunker(parameters);
             default:
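For context, here is a minimal caller-side sketch (not part of this commit) of the pattern this change enables: the AnalysisRegistry now travels inside the parameters map under ANALYSIS_REGISTRY_FIELD instead of being passed to the factory as a separate constructor argument. The analysisRegistry variable and the map setup are assumptions for illustration.

    // Hypothetical usage; assumes an AnalysisRegistry instance named analysisRegistry is in scope.
    Map<String, Object> parameters = new HashMap<>();
    parameters.put(FixedTokenLengthChunker.ANALYSIS_REGISTRY_FIELD, analysisRegistry);
    Chunker chunker = ChunkerFactory.create(ChunkerFactory.FIXED_TOKEN_LENGTH_ALGORITHM, parameters);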
FixedTokenLengthChunker.java

@@ -32,15 +32,15 @@ public class FixedTokenLengthChunker implements Chunker {
 
     // default values for each parameter
     private static final int DEFAULT_TOKEN_LIMIT = 384;
-    private static final Double DEFAULT_OVERLAP_RATE = 0.0;
+    private static final double DEFAULT_OVERLAP_RATE = 0.0;
     private static final int DEFAULT_MAX_TOKEN_COUNT = 10000;
     private static final String DEFAULT_TOKENIZER = "standard";
 
     private static final String DEFAULT_TOKEN_CONCATENATOR = " ";
 
-    private static final Double OVERLAP_RATE_UPPER_BOUND = 0.5;
+    private static final double OVERLAP_RATE_UPPER_BOUND = 0.5;
 
-    private Double overlapRate;
+    private double overlapRate;
 
     private int tokenLimit;
 
@@ -50,9 +50,9 @@ public class FixedTokenLengthChunker implements Chunker {
 
     private final AnalysisRegistry analysisRegistry;
 
-    public FixedTokenLengthChunker(AnalysisRegistry analysisRegistry, Map<String, Object> parameters) {
+    public FixedTokenLengthChunker(Map<String, Object> parameters) {
         validateParameters(parameters);
-        this.analysisRegistry = analysisRegistry;
+        this.analysisRegistry = (AnalysisRegistry) parameters.get(ANALYSIS_REGISTRY_FIELD);
     }
 
     /**
@@ -79,8 +79,8 @@ public void validateParameters(Map<String, Object> parameters) {
                 "fixed length parameter [" + OVERLAP_RATE_FIELD + "] cannot be cast to [" + Number.class.getName() + "]"
             );
         }
-        Double overlapRate = Double.valueOf(overlapRateString);
-        if (overlapRate < 0 || overlapRate.compareTo(OVERLAP_RATE_UPPER_BOUND) > 0) {
+        double overlapRate = NumberUtils.createDouble(overlapRateString);
+        if (overlapRate < 0 || overlapRate > OVERLAP_RATE_UPPER_BOUND) {
             throw new IllegalArgumentException(
                 "fixed length parameter [" + OVERLAP_RATE_FIELD + "] must be between 0 and " + OVERLAP_RATE_UPPER_BOUND
             );
@@ -103,12 +103,13 @@ public static String validateStringParameters(
             // all parameters are optional
             return defaultValue;
         }
-        if (!(parameters.get(fieldName) instanceof String)) {
+        Object fieldValue = parameters.get(fieldName);
+        if (!(fieldValue instanceof String)) {
             throw new IllegalArgumentException("Chunker parameter [" + fieldName + "] cannot be cast to [" + String.class.getName() + "]");
-        } else if (StringUtils.isEmpty(parameters.get(fieldName).toString()) && !allowEmpty) {
+        } else if (!allowEmpty && StringUtils.isEmpty(fieldValue.toString())) {
             throw new IllegalArgumentException("Chunker parameter: " + fieldName + " should not be empty.");
         }
-        return (String) parameters.get(fieldName);
+        return (String) fieldValue;
     }
 
     private int validatePositiveIntegerParameter(Map<String, Object> parameters, String fieldName, int defaultValue) {
@@ -133,16 +134,13 @@ private int validatePositiveIntegerParameter(Map<String, Object> parameters, Str
      * Return the chunked passages for fixed token length algorithm
      *
      * @param content input string
-     * @param parameters a map containing parameters, containing the following parameters
-     * 1. tokenizer the <a href="https://opensearch.org/docs/latest/analyzers/tokenizers/index/">analyzer tokenizer</a> in OpenSearch
-     * 2. token_limit the token limit for each chunked passage
-     * 3. overlap_rate the overlapping degree for each chunked passage, indicating how many token comes from the previous passage
-     * 4. max_token_count the max token limit for the tokenizer
+     * @param runtimeParameters a map containing runtimeParameters, containing the following runtimeParameters
+     * max_token_count the max token limit for the tokenizer
      */
     @Override
-    public List<String> chunk(String content, Map<String, Object> parameters) {
-        // prior to chunking, parameters have been validated
-        int maxTokenCount = validatePositiveIntegerParameter(parameters, MAX_TOKEN_COUNT_FIELD, DEFAULT_MAX_TOKEN_COUNT);
+    public List<String> chunk(String content, Map<String, Object> runtimeParameters) {
+        // prior to chunking, runtimeParameters have been validated
+        int maxTokenCount = validatePositiveIntegerParameter(runtimeParameters, MAX_TOKEN_COUNT_FIELD, DEFAULT_MAX_TOKEN_COUNT);
 
         List<String> tokens = tokenize(content, tokenizer, maxTokenCount);
         List<String> passages = new ArrayList<>();
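Taken together, the constructor and chunk() changes split configuration into construction-time parameters (tokenizer, token_limit, overlap_rate, plus the analysis registry) and per-call runtime parameters (max_token_count). Below is a hypothetical end-to-end sketch under that reading, with the field names statically imported from FixedTokenLengthChunker as in the tests; the concrete values and the analysisRegistry variable are assumptions.

    // Construction-time parameters, including the analysis registry (assumed in scope).
    Map<String, Object> parameters = new HashMap<>();
    parameters.put(ANALYSIS_REGISTRY_FIELD, analysisRegistry);
    parameters.put(TOKENIZER_FIELD, "standard");
    parameters.put(TOKEN_LIMIT_FIELD, 384);
    parameters.put(OVERLAP_RATE_FIELD, 0.2); // validated to lie between 0 and 0.5
    FixedTokenLengthChunker chunker = new FixedTokenLengthChunker(parameters);

    // Runtime parameters supplied per chunk() call.
    Map<String, Object> runtimeParameters = new HashMap<>();
    runtimeParameters.put(MAX_TOKEN_COUNT_FIELD, 10000);
    List<String> passages = chunker.chunk("Some long input document to split into passages.", runtimeParameters);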
ChunkerFactoryTests.java

@@ -8,6 +8,7 @@
 import org.opensearch.index.analysis.AnalysisRegistry;
 import org.opensearch.test.OpenSearchTestCase;
 
+import java.util.HashMap;
 import java.util.Map;
 import java.util.Set;
 
@@ -24,13 +25,13 @@ public void testGetAllChunkers() {
     }
 
     public void testCreate_FixedTokenLength() {
-        Chunker chunker = ChunkerFactory.create(ChunkerFactory.FIXED_TOKEN_LENGTH_ALGORITHM, Map.of(ANALYSIS_REGISTRY_FIELD, analysisRegistry));
+        Chunker chunker = ChunkerFactory.create(ChunkerFactory.FIXED_TOKEN_LENGTH_ALGORITHM, createChunkParameters());
         assertNotNull(chunker);
         assertTrue(chunker instanceof FixedTokenLengthChunker);
     }
 
     public void testCreate_Delimiter() {
-        Chunker chunker = ChunkerFactory.create(ChunkerFactory.DELIMITER_ALGORITHM, Map.of(ANALYSIS_REGISTRY_FIELD, analysisRegistry));
+        Chunker chunker = ChunkerFactory.create(ChunkerFactory.DELIMITER_ALGORITHM, createChunkParameters());
         assertNotNull(chunker);
         assertTrue(chunker instanceof DelimiterChunker);
     }
@@ -39,8 +40,14 @@ public void testCreate_Invalid() {
         String invalidChunkerType = "Invalid Chunker Type";
         IllegalArgumentException illegalArgumentException = assertThrows(
             IllegalArgumentException.class,
-            () -> ChunkerFactory.create(invalidChunkerType, Map.of(ANALYSIS_REGISTRY_FIELD, analysisRegistry))
+            () -> ChunkerFactory.create(invalidChunkerType, createChunkParameters())
         );
         assert (illegalArgumentException.getMessage().contains("chunker type [" + invalidChunkerType + "] is not supported."));
     }
+
+    private Map<String, Object> createChunkParameters() {
+        Map<String, Object> parameters = new HashMap<>();
+        parameters.put(ANALYSIS_REGISTRY_FIELD, analysisRegistry);
+        return parameters;
+    }
 }
FixedTokenLengthChunkerTests.java

@@ -23,9 +23,10 @@
 
 import static java.util.Collections.singletonList;
 import static java.util.Collections.singletonMap;
-import static org.opensearch.neuralsearch.processor.chunker.FixedTokenLengthChunker.TOKENIZER_FIELD;
+import static org.opensearch.neuralsearch.processor.chunker.FixedTokenLengthChunker.ANALYSIS_REGISTRY_FIELD;
 import static org.opensearch.neuralsearch.processor.chunker.FixedTokenLengthChunker.TOKEN_LIMIT_FIELD;
 import static org.opensearch.neuralsearch.processor.chunker.FixedTokenLengthChunker.OVERLAP_RATE_FIELD;
+import static org.opensearch.neuralsearch.processor.chunker.FixedTokenLengthChunker.TOKENIZER_FIELD;
 import static org.opensearch.neuralsearch.processor.chunker.FixedTokenLengthChunker.MAX_TOKEN_COUNT_FIELD;
 
 public class FixedTokenLengthChunkerTests extends OpenSearchTestCase {
@@ -40,6 +41,8 @@ public void setup() {
 
     @SneakyThrows
    public FixedTokenLengthChunker createFixedTokenLengthChunker(Map<String, Object> parameters) {
+        Map<String, Object> nonruntimeParameters = new HashMap<>();
+        nonruntimeParameters.putAll(parameters);
         Settings settings = Settings.builder().put(Environment.PATH_HOME_SETTING.getKey(), createTempDir().toString()).build();
         Environment environment = TestEnvironment.newEnvironment(settings);
         AnalysisPlugin plugin = new AnalysisPlugin() {
@@ -56,7 +59,8 @@ public Map<String, AnalysisModule.AnalysisProvider<TokenizerFactory>> getTokeniz
             }
         };
         AnalysisRegistry analysisRegistry = new AnalysisModule(environment, singletonList(plugin)).getAnalysisRegistry();
-        return new FixedTokenLengthChunker(analysisRegistry, parameters);
+        nonruntimeParameters.put(ANALYSIS_REGISTRY_FIELD, analysisRegistry);
+        return new FixedTokenLengthChunker(nonruntimeParameters);
     }
 
     public void testValidateParameters_whenNoParams_thenSuccessful() {
