basic unit tests for document chunking processor
Signed-off-by: yuye-aws <[email protected]>
yuye-aws committed Feb 27, 2024
1 parent 51905df commit d187cbb
Showing 4 changed files with 130 additions and 9 deletions.
DocumentChunkingProcessor.java
@@ -27,13 +27,14 @@
import org.opensearch.index.mapper.IndexFieldMapper;

import static org.opensearch.ingest.ConfigurationUtils.readMap;
import static org.opensearch.neuralsearch.processor.InferenceProcessor.FIELD_MAP_FIELD;

public final class DocumentChunkingProcessor extends AbstractProcessor {

public static final String TYPE = "chunking";
public static final String OUTPUT_FIELD = "output_field";

public static final String FIELD_MAP_FIELD = "field_map";

private final Map<String, Object> fieldMap;

private final Set<String> supportedChunkers = ChunkerFactory.getAllChunkers();
@@ -129,7 +130,6 @@ private void validateDocumentChunkingFieldMap(Map<String, Object> fieldMap) {
private void validateContent(Object content, String inputField) {
// content can be a map, a string, or a list of strings
if (content instanceof Map) {
System.out.println("map type");
@SuppressWarnings("unchecked")
Map<String, Object> contentMap = (Map<String, Object>) content;
for (Map.Entry<String, Object> contentEntry : contentMap.entrySet()) {
@@ -214,7 +214,7 @@ public IngestDocument execute(IngestDocument document) {
@SuppressWarnings("unchecked")
Map<String, Object> chunkerParameters = (Map<String, Object>) parameterEntry.getValue();
if (Objects.equals(parameterKey, ChunkerFactory.FIXED_LENGTH_ALGORITHM)) {
// add maxTokenCount to chunker parameters
// for fixed token length algorithm, add maxTokenCount to chunker parameters
Map<String, Object> sourceAndMetadataMap = document.getSourceAndMetadata();
int maxTokenCount = IndexSettings.MAX_TOKEN_COUNT_SETTING.get(settings);
String indexName = sourceAndMetadataMap.get(IndexFieldMapper.NAME).toString();
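For orientation, here is a minimal illustrative sketch (not the committed source, which is truncated above) of how the fixed token length branch might resolve maxTokenCount: read the node-level default from the processor's settings, then prefer the target index's own setting when its metadata is available in cluster state. The resolveMaxTokenCount helper name is hypothetical, and the sketch assumes the processor holds settings and clusterService fields, as the factory constructor used in the tests below suggests.

// Illustrative sketch only (hypothetical helper, not the committed code): resolve the
// max token count for the fixed token length algorithm, preferring the index-level
// setting over the node-level default.
private int resolveMaxTokenCount(Map<String, Object> sourceAndMetadataMap) {
    int maxTokenCount = IndexSettings.MAX_TOKEN_COUNT_SETTING.get(settings);
    String indexName = sourceAndMetadataMap.get(IndexFieldMapper.NAME).toString();
    IndexMetadata indexMetadata = clusterService.state().metadata().index(indexName);
    if (indexMetadata != null) {
        // the index's own setting overrides the node-level default
        maxTokenCount = IndexSettings.MAX_TOKEN_COUNT_SETTING.get(indexMetadata.getSettings());
    }
    return maxTokenCount;
}

The unit test below stubs metadata.index(anyString()) to return null, so a lookup like this falls back to the node-level default.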
DocumentChunkingProcessorTests.java
@@ -4,4 +4,126 @@
*/
package org.opensearch.neuralsearch.processor;

public class DocumentChunkingProcessorTests {}
import lombok.SneakyThrows;
import org.apache.lucene.tests.analysis.MockTokenizer;
import org.junit.Before;
import org.opensearch.cluster.ClusterState;
import org.opensearch.cluster.metadata.Metadata;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.settings.Settings;
import org.opensearch.env.Environment;
import org.opensearch.env.TestEnvironment;
import org.opensearch.index.analysis.AnalysisRegistry;
import org.opensearch.index.analysis.TokenizerFactory;
import org.opensearch.index.mapper.IndexFieldMapper;
import org.opensearch.indices.IndicesService;
import org.opensearch.indices.analysis.AnalysisModule;
import org.opensearch.ingest.IngestDocument;
import org.opensearch.ingest.Processor;
import org.opensearch.neuralsearch.processor.chunker.ChunkerFactory;
import org.opensearch.neuralsearch.processor.chunker.FixedTokenLengthChunker;
import org.opensearch.plugins.AnalysisPlugin;
import org.opensearch.test.OpenSearchTestCase;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import static java.util.Collections.singletonList;
import static java.util.Collections.singletonMap;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.Mockito.when;
import static org.mockito.Mockito.mock;

public class DocumentChunkingProcessorTests extends OpenSearchTestCase {

private DocumentChunkingProcessor.Factory factory;

private static final String PROCESSOR_TAG = "mockTag";
private static final String DESCRIPTION = "mockDescription";

@SneakyThrows
private AnalysisRegistry getAnalysisRegistry() {
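// build an AnalysisRegistry for the chunker; the plugin registers a mock "keyword" tokenizer alongside the default analysis components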
Settings settings = Settings.builder().put(Environment.PATH_HOME_SETTING.getKey(), createTempDir().toString()).build();
Environment environment = TestEnvironment.newEnvironment(settings);
AnalysisPlugin plugin = new AnalysisPlugin() {

@Override
public Map<String, AnalysisModule.AnalysisProvider<TokenizerFactory>> getTokenizers() {
return singletonMap(
"keyword",
(indexSettings, environment, name, settings) -> TokenizerFactory.newFactory(
name,
() -> new MockTokenizer(MockTokenizer.KEYWORD, false)
)
);
}
};
return new AnalysisModule(environment, singletonList(plugin)).getAnalysisRegistry();
}

@Before
public void setup() {
Settings settings = Settings.builder().build();
Metadata metadata = mock(Metadata.class);
ClusterState clusterState = mock(ClusterState.class);
ClusterService clusterService = mock(ClusterService.class);
IndicesService indicesService = mock(IndicesService.class);
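// no index metadata is available in this test, so max token count resolution is expected to fall back to the node-level setting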
when(metadata.index(anyString())).thenReturn(null);
when(clusterState.metadata()).thenReturn(metadata);
when(clusterService.state()).thenReturn(clusterState);
factory = new DocumentChunkingProcessor.Factory(settings, clusterService, indicesService, getAnalysisRegistry());
}

@SneakyThrows
public void testGetType() {
DocumentChunkingProcessor processor = createFixedTokenLengthInstance();
String type = processor.getType();
assertEquals(DocumentChunkingProcessor.TYPE, type);
}

private Map<String, Object> createFixedTokenLengthParameters() {
Map<String, Object> parameters = new HashMap<>();
parameters.put(FixedTokenLengthChunker.TOKEN_LIMIT, 10);
return parameters;
}

@SneakyThrows
private DocumentChunkingProcessor createFixedTokenLengthInstance() {
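// builds the processor config: field_map -> "body" -> { fixed token length parameters, output_field: "body_chunk" }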
Map<String, Object> config = new HashMap<>();
Map<String, Object> fieldParameters = new HashMap<>();
Map<String, Object> chunkerParameters = new HashMap<>();
chunkerParameters.put(ChunkerFactory.FIXED_LENGTH_ALGORITHM, createFixedTokenLengthParameters());
chunkerParameters.put(DocumentChunkingProcessor.OUTPUT_FIELD, "body_chunk");
fieldParameters.put("body", chunkerParameters);
config.put(DocumentChunkingProcessor.FIELD_MAP_FIELD, fieldParameters);
Map<String, Processor.Factory> registry = new HashMap<>();
return factory.create(registry, PROCESSOR_TAG, DESCRIPTION, config);
}

private IngestDocument createIngestDocument() {
Map<String, Object> sourceAndMetadata = new HashMap<>();
sourceAndMetadata.put(
"body",
"This is an example document to be chunked. The document contains a single paragraph, two sentences and 24 tokens by standard tokenizer in OpenSearch."
);
sourceAndMetadata.put(IndexFieldMapper.NAME, "_index");
return new IngestDocument(sourceAndMetadata, new HashMap<>());
}

@SneakyThrows
public void testExecute_withFixedTokenLength_successful() {
DocumentChunkingProcessor processor = createFixedTokenLengthInstance();
IngestDocument ingestDocument = createIngestDocument();
IngestDocument document = processor.execute(ingestDocument);
assert document.getSourceAndMetadata().containsKey("body_chunk");
Object passages = document.getSourceAndMetadata().get("body_chunk");
assert (passages instanceof List<?>);
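// expected passages overlap by two tokens; only token_limit was configured, so this overlap reflects the chunker's default settings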
List<String> expectedPassages = new ArrayList<>();
expectedPassages.add("This is an example document to be chunked The document");
expectedPassages.add("The document contains a single paragraph two sentences and 24");
expectedPassages.add("and 24 tokens by standard tokenizer in OpenSearch");
assertEquals(expectedPassages, passages);
}
}
ChunkerFactoryTests.java
@@ -13,29 +13,29 @@
public class ChunkerFactoryTests extends OpenSearchTestCase {

@Mock
private AnalysisRegistry registry;
private AnalysisRegistry analysisRegistry;

public void testGetAllChunkers() {
Set<String> expected = Set.of(ChunkerFactory.FIXED_LENGTH_ALGORITHM, ChunkerFactory.DELIMITER_ALGORITHM);
assertEquals(expected, ChunkerFactory.getAllChunkers());
}

public void testCreate_FixedTokenLength() {
IFieldChunker chunker = ChunkerFactory.create(ChunkerFactory.FIXED_LENGTH_ALGORITHM, registry);
IFieldChunker chunker = ChunkerFactory.create(ChunkerFactory.FIXED_LENGTH_ALGORITHM, analysisRegistry);
assertNotNull(chunker);
assertTrue(chunker instanceof FixedTokenLengthChunker);
}

public void testCreate_Delimiter() {
IFieldChunker chunker = ChunkerFactory.create(ChunkerFactory.DELIMITER_ALGORITHM, registry);
IFieldChunker chunker = ChunkerFactory.create(ChunkerFactory.DELIMITER_ALGORITHM, analysisRegistry);
assertNotNull(chunker);
assertTrue(chunker instanceof DelimiterChunker);
}

public void testCreate_Invalid() {
IllegalArgumentException illegalArgumentException = assertThrows(
IllegalArgumentException.class,
() -> ChunkerFactory.create("Invalid Chunker Type", registry)
() -> ChunkerFactory.create("Invalid Chunker Type", analysisRegistry)
);
assertEquals(
"chunker type [Invalid Chunker Type] is not supported. Supported chunkers types are [fix_length, delimiter]",
FixedTokenLengthChunkerTests.java
@@ -36,7 +36,6 @@ public class FixedTokenLengthChunkerTests extends OpenSearchTestCase {
public void setup() {
Settings settings = Settings.builder().put(Environment.PATH_HOME_SETTING.getKey(), createTempDir().toString()).build();
Environment environment = TestEnvironment.newEnvironment(settings);

AnalysisPlugin plugin = new AnalysisPlugin() {

@Override
