diff --git a/README.md b/README.md index 3ae61e23..dc3657d2 100644 --- a/README.md +++ b/README.md @@ -113,11 +113,9 @@ use text_splitter::MarkdownSplitter; let max_characters = 1000; // Default implementation uses character count for chunk size. // Can also use all of the same tokenizer implementations as `TextSplitter`. -let splitter = MarkdownSplitter::default() - // Optionally can also have the splitter trim whitespace for you - .with_trim_chunks(true); +let splitter = MarkdownSplitter::new(max_characters); -let chunks = splitter.chunks("# Header\n\nyour document text", max_characters); +let chunks = splitter.chunks("# Header\n\nyour document text"); ``` ## Method diff --git a/benches/chunk_size.rs b/benches/chunk_size.rs index 6fd36cd4..30d3d9d4 100644 --- a/benches/chunk_size.rs +++ b/benches/chunk_size.rs @@ -43,7 +43,7 @@ mod text { #[divan::bench(args = TEXT_FILENAMES, consts = CHUNK_SIZES)] fn characters(bencher: Bencher<'_, '_>, filename: &str) { - bench::<_, _>(bencher, filename, || TextSplitter::new(N)); + bench(bencher, filename, || TextSplitter::new(N)); } #[cfg(feature = "tiktoken-rs")] @@ -51,7 +51,7 @@ mod text { fn tiktoken(bencher: Bencher<'_, '_>, filename: &str) { use text_splitter::ChunkConfig; - bench::<_, _>(bencher, filename, || { + bench(bencher, filename, || { TextSplitter::new(ChunkConfig::new(N).with_sizer(tiktoken_rs::cl100k_base().unwrap())) }); } @@ -59,7 +59,7 @@ mod text { #[cfg(feature = "tokenizers")] #[divan::bench(args = TEXT_FILENAMES, consts = CHUNK_SIZES)] fn tokenizers(bencher: Bencher<'_, '_>, filename: &str) { - bench::<_, _>(bencher, filename, || { + bench(bencher, filename, || { TextSplitter::new(ChunkConfig::new(N).with_sizer( tokenizers::Tokenizer::from_pretrained("bert-base-cased", None).unwrap(), )) @@ -73,15 +73,15 @@ mod markdown { use std::fs; use divan::{black_box_drop, counter::BytesCount, Bencher}; - use text_splitter::{ChunkSizer, MarkdownSplitter}; + use text_splitter::{ChunkConfig, ChunkSizer, 
MarkdownSplitter}; use crate::CHUNK_SIZES; const MARKDOWN_FILENAMES: &[&str] = &["commonmark_spec"]; - fn bench(bencher: Bencher<'_, '_>, filename: &str, gen_splitter: G) + fn bench(bencher: Bencher<'_, '_>, filename: &str, gen_splitter: G) where - G: Fn() -> MarkdownSplitter + Sync, + G: Fn() -> MarkdownSplitter + Sync, S: ChunkSizer, { bencher @@ -93,30 +93,32 @@ mod markdown { }) .input_counter(|(_, text)| BytesCount::of_str(text)) .bench_values(|(splitter, text)| { - splitter.chunks(&text, N).for_each(black_box_drop); + splitter.chunks(&text).for_each(black_box_drop); }); } #[divan::bench(args = MARKDOWN_FILENAMES, consts = CHUNK_SIZES)] fn characters(bencher: Bencher<'_, '_>, filename: &str) { - bench::(bencher, filename, MarkdownSplitter::default); + bench(bencher, filename, || MarkdownSplitter::new(N)); } #[cfg(feature = "tiktoken-rs")] #[divan::bench(args = MARKDOWN_FILENAMES, consts = CHUNK_SIZES)] fn tiktoken(bencher: Bencher<'_, '_>, filename: &str) { - bench::(bencher, filename, || { - MarkdownSplitter::new(tiktoken_rs::cl100k_base().unwrap()) + bench(bencher, filename, || { + MarkdownSplitter::new( + ChunkConfig::new(N).with_sizer(tiktoken_rs::cl100k_base().unwrap()), + ) }); } #[cfg(feature = "tokenizers")] #[divan::bench(args = MARKDOWN_FILENAMES, consts = CHUNK_SIZES)] fn tokenizers(bencher: Bencher<'_, '_>, filename: &str) { - bench::(bencher, filename, || { - MarkdownSplitter::new( + bench(bencher, filename, || { + MarkdownSplitter::new(ChunkConfig::new(N).with_sizer( tokenizers::Tokenizer::from_pretrained("bert-base-cased", None).unwrap(), - ) + )) }); } } diff --git a/bindings/python/semantic_text_splitter.pyi b/bindings/python/semantic_text_splitter.pyi index 179f055b..b2df03db 100644 --- a/bindings/python/semantic_text_splitter.pyi +++ b/bindings/python/semantic_text_splitter.pyi @@ -127,7 +127,7 @@ class TextSplitter: up the chunk until the lower range is met. 
trim (bool, optional): Specify whether chunks should have whitespace trimmed from the beginning and end or not. If False, joining all chunks will return the original - string. Defaults to True.. + string. Defaults to True. Returns: The new text splitter @@ -149,8 +149,7 @@ class TextSplitter: up the chunk until the lower range is met. trim (bool, optional): Specify whether chunks should have whitespace trimmed from the beginning and end or not. If False, joining all chunks will return the original - string. Defaults to True.. - + string. Defaults to True. Returns: The new text splitter @@ -171,7 +170,7 @@ class TextSplitter: up the chunk until the lower range is met. trim (bool, optional): Specify whether chunks should have whitespace trimmed from the beginning and end or not. If False, joining all chunks will return the original - string. Defaults to True.. + string. Defaults to True. Returns: The new text splitter @@ -258,10 +257,10 @@ class MarkdownSplitter: # Maximum number of characters in a chunk max_characters = 1000 # Optionally can also have the splitter not trim whitespace for you - splitter = MarkdownSplitter() - # splitter = MarkdownSplitter(trim_chunks=False) + splitter = MarkdownSplitter(max_characters) + # splitter = MarkdownSplitter(max_characters, trim=False) - chunks = splitter.chunks("# Header\n\nyour document text", max_characters) + chunks = splitter.chunks("# Header\n\nyour document text") ``` ### Using a Range for Chunk Capacity @@ -275,11 +274,11 @@ class MarkdownSplitter: ```python from semantic_text_splitter import MarkdownSplitter - splitter = MarkdownSplitter() + splitter = MarkdownSplitter(capacity=(200,1000)) # Maximum number of characters in a chunk. Will fill up the # chunk until it is somewhere in this range. 
- chunks = splitter.chunks("# Header\n\nyour document text", chunk_capacity=(200,1000)) + chunks = splitter.chunks("# Header\n\nyour document text") ``` ### Using a Hugging Face Tokenizer @@ -291,9 +290,9 @@ class MarkdownSplitter: # Maximum number of tokens in a chunk max_tokens = 1000 tokenizer = Tokenizer.from_pretrained("bert-base-uncased") - splitter = MarkdownSplitter.from_huggingface_tokenizer(tokenizer) + splitter = MarkdownSplitter.from_huggingface_tokenizer(tokenizer, max_tokens) - chunks = splitter.chunks("# Header\n\nyour document text", max_tokens) + chunks = splitter.chunks("# Header\n\nyour document text") ``` ### Using a Tiktoken Tokenizer @@ -304,9 +303,9 @@ class MarkdownSplitter: # Maximum number of tokens in a chunk max_tokens = 1000 - splitter = MarkdownSplitter.from_tiktoken_model("gpt-3.5-turbo") + splitter = MarkdownSplitter.from_tiktoken_model("gpt-3.5-turbo", max_tokens) - chunks = splitter.chunks("# Header\n\nyour document text", max_tokens) + chunks = splitter.chunks("# Header\n\nyour document text") ``` ### Using a Custom Callback @@ -315,36 +314,45 @@ class MarkdownSplitter: from semantic_text_splitter import MarkdownSplitter # Optionally can also have the splitter trim whitespace for you - splitter = MarkdownSplitter.from_callback(lambda text: len(text)) + splitter = MarkdownSplitter.from_callback(lambda text: len(text), 1000) # Maximum number of tokens in a chunk. Will fill up the # chunk until it is somewhere in this range. - chunks = splitter.chunks("# Header\n\nyour document text", chunk_capacity=(200,1000)) + chunks = splitter.chunks("# Header\n\nyour document text") ``` Args: - trim_chunks (bool, optional): Specify whether chunks should have whitespace trimmed from the + capacity (int | (int, int)): The capacity of characters in each chunk. If a + single int, then chunks will be filled up as much as possible, without going over + that number. 
If a tuple of two integers is provided, a chunk will be considered + "full" once it is within the two numbers (inclusive range). So it will only fill + up the chunk until the lower range is met. + trim (bool, optional): Specify whether chunks should have whitespace trimmed from the beginning and end or not. If False, joining all chunks will return the original - string. Indentation however will be preserved if the chunk also includes multiple lines. - Extra newlines are always removed, but if the text would include multiple indented list - items, the indentation of the first element will also be preserved. - Defaults to True. + string. Defaults to True. """ - def __init__(self, trim_chunks: bool = True) -> None: ... + def __init__( + self, capacity: Union[int, Tuple[int, int]], trim: bool = True + ) -> None: ... @staticmethod def from_huggingface_tokenizer( - tokenizer, trim_chunks: bool = True + tokenizer, capacity: Union[int, Tuple[int, int]], trim: bool = True ) -> MarkdownSplitter: """Instantiate a new markdown splitter from a Hugging Face Tokenizer instance. Args: tokenizer (Tokenizer): A `tokenizers.Tokenizer` you want to use to count tokens for each chunk. - trim_chunks (bool, optional): Specify whether chunks should have whitespace trimmed from the - beginning and end or not. If False, joining all chunks will return the original - string. Defaults to True. + capacity (int | (int, int)): The capacity of characters in each chunk. If a + single int, then chunks will be filled up as much as possible, without going over + that number. If a tuple of two integers is provided, a chunk will be considered + "full" once it is within the two numbers (inclusive range). So it will only fill + up the chunk until the lower range is met. + trim (bool, optional): Specify whether chunks should have whitespace trimmed from the + beginning and end or not. If False, joining all chunks will return the original + string. Defaults to True. 
Returns: The new markdown splitter @@ -352,14 +360,19 @@ class MarkdownSplitter: @staticmethod def from_huggingface_tokenizer_str( - json: str, trim_chunks: bool = True + json: str, capacity: Union[int, Tuple[int, int]], trim: bool = True ) -> MarkdownSplitter: """Instantiate a new markdown splitter from the given Hugging Face Tokenizer JSON string. Args: json (str): A valid JSON string representing a previously serialized Hugging Face Tokenizer - trim_chunks (bool, optional): Specify whether chunks should have whitespace trimmed from the + capacity (int | (int, int)): The capacity of characters in each chunk. If a + single int, then chunks will be filled up as much as possible, without going over + that number. If a tuple of two integers is provided, a chunk will be considered + "full" once it is within the two numbers (inclusive range). So it will only fill + up the chunk until the lower range is met. + trim (bool, optional): Specify whether chunks should have whitespace trimmed from the beginning and end or not. If False, joining all chunks will return the original string. Defaults to True. @@ -369,29 +382,40 @@ class MarkdownSplitter: @staticmethod def from_huggingface_tokenizer_file( - path: str, trim_chunks: bool = True + path: str, capacity: Union[int, Tuple[int, int]], trim: bool = True ) -> MarkdownSplitter: """Instantiate a new markdown splitter from the Hugging Face tokenizer file at the given path. Args: path (str): A path to a local JSON file representing a previously serialized Hugging Face tokenizer. - trim_chunks (bool, optional): Specify whether chunks should have whitespace trimmed from the + capacity (int | (int, int)): The capacity of characters in each chunk. If a + single int, then chunks will be filled up as much as possible, without going over + that number. If a tuple of two integers is provided, a chunk will be considered + "full" once it is within the two numbers (inclusive range). 
So it will only fill + up the chunk until the lower range is met. + trim (bool, optional): Specify whether chunks should have whitespace trimmed from the beginning and end or not. If False, joining all chunks will return the original string. Defaults to True. - Returns: The new markdown splitter """ @staticmethod - def from_tiktoken_model(model: str, trim_chunks: bool = True) -> MarkdownSplitter: + def from_tiktoken_model( + model: str, capacity: Union[int, Tuple[int, int]], trim: bool = True + ) -> MarkdownSplitter: """Instantiate a new markdown splitter based on an OpenAI Tiktoken tokenizer. Args: model (str): The OpenAI model name you want to retrieve a tokenizer for. - trim_chunks (bool, optional): Specify whether chunks should have whitespace trimmed from the + capacity (int | (int, int)): The capacity of characters in each chunk. If a + single int, then chunks will be filled up as much as possible, without going over + that number. If a tuple of two integers is provided, a chunk will be considered + "full" once it is within the two numbers (inclusive range). So it will only fill + up the chunk until the lower range is met. + trim (bool, optional): Specify whether chunks should have whitespace trimmed from the beginning and end or not. If False, joining all chunks will return the original string. Defaults to True. @@ -401,14 +425,21 @@ class MarkdownSplitter: @staticmethod def from_callback( - callback: Callable[[str], int], trim_chunks: bool = True + callback: Callable[[str], int], + capacity: Union[int, Tuple[int, int]], + trim: bool = True, ) -> MarkdownSplitter: """Instantiate a new markdown splitter based on a custom callback. Args: callback (Callable[[str], int]): A lambda or other function that can be called. It will be provided a piece of text, and it should return an integer value for the size. 
- trim_chunks (bool, optional): Specify whether chunks should have whitespace trimmed from the + capacity (int | (int, int)): The capacity of characters in each chunk. If a + single int, then chunks will be filled up as much as possible, without going over + that number. If a tuple of two integers is provided, a chunk will be considered + "full" once it is within the two numbers (inclusive range). So it will only fill + up the chunk until the lower range is met. + trim (bool, optional): Specify whether chunks should have whitespace trimmed from the beginning and end or not. If False, joining all chunks will return the original string. Defaults to True. @@ -416,9 +447,7 @@ class MarkdownSplitter: The new markdown splitter """ - def chunks( - self, text: str, chunk_capacity: Union[int, Tuple[int, int]] - ) -> List[str]: + def chunks(self, text: str) -> List[str]: """Generate a list of chunks from a given text. Each chunk will be up to the `chunk_capacity`. ## Method @@ -445,31 +474,19 @@ class MarkdownSplitter: Args: text (str): Text to split. - chunk_capacity (int | (int, int)): The capacity of characters in each chunk. If a - single int, then chunks will be filled up as much as possible, without going over - that number. If a tuple of two integers is provided, a chunk will be considered - "full" once it is within the two numbers (inclusive range). So it will only fill - up the chunk until the lower range is met. Returns: A list of strings, one for each chunk. If `trim_chunks` was specified in the text splitter, then each chunk will already be trimmed as well. """ - def chunk_indices( - self, text: str, chunk_capacity: Union[int, Tuple[int, int]] - ) -> List[Tuple[int, str]]: + def chunk_indices(self, text: str) -> List[Tuple[int, str]]: """Generate a list of chunks from a given text, along with their character offsets in the original text. Each chunk will be up to the `chunk_capacity`. See `chunks` for more information. Args: text (str): Text to split. 
- chunk_capacity (int | (int, int)): The capacity of characters in each chunk. If a - single int, then chunks will be filled up as much as possible, without going over - that number. If a tuple of two integers is provided, a chunk will be considered - "full" once it is within the two numbers (inclusive range). So it will only fill - up the chunk until the lower range is met. Returns: A list of tuples, one for each chunk. The first item will be the character offset relative diff --git a/bindings/python/src/lib.rs b/bindings/python/src/lib.rs index b7e07bac..23377214 100644 --- a/bindings/python/src/lib.rs +++ b/bindings/python/src/lib.rs @@ -143,10 +143,10 @@ from semantic_text_splitter import TextSplitter # Maximum number of characters in a chunk max_characters = 1000 # Optionally can also have the splitter not trim whitespace for you -splitter = TextSplitter() -# splitter = TextSplitter(trim_chunks=False) +splitter = TextSplitter(max_characters) +# splitter = TextSplitter(max_characters, trim=False) -chunks = splitter.chunks("your document text", max_characters) +chunks = splitter.chunks("your document text") ``` ### Using a Range for Chunk Capacity @@ -160,11 +160,12 @@ It is always possible that a chunk may be returned that is less than the `start` ```python from semantic_text_splitter import TextSplitter -splitter = TextSplitter() # Maximum number of characters in a chunk. Will fill up the # chunk until it is somewhere in this range. 
-chunks = splitter.chunks("your document text", chunk_capacity=(200,1000)) +splitter = TextSplitter((200,1000)) + +chunks = splitter.chunks("your document text") ``` ### Using a Hugging Face Tokenizer @@ -176,9 +177,9 @@ from tokenizers import Tokenizer # Maximum number of tokens in a chunk max_tokens = 1000 tokenizer = Tokenizer.from_pretrained("bert-base-uncased") -splitter = TextSplitter.from_huggingface_tokenizer(tokenizer) +splitter = TextSplitter.from_huggingface_tokenizer(tokenizer, max_tokens) -chunks = splitter.chunks("your document text", max_tokens) +chunks = splitter.chunks("your document text") ``` ### Using a Tiktoken Tokenizer @@ -189,9 +190,9 @@ from semantic_text_splitter import TextSplitter # Maximum number of tokens in a chunk max_tokens = 1000 -splitter = TextSplitter.from_tiktoken_model("gpt-3.5-turbo") +splitter = TextSplitter.from_tiktoken_model("gpt-3.5-turbo", max_tokens) -chunks = splitter.chunks("your document text", max_tokens) +chunks = splitter.chunks("your document text") ``` ### Using a Custom Callback @@ -199,12 +200,9 @@ chunks = splitter.chunks("your document text", max_tokens) ```python from semantic_text_splitter import TextSplitter -# Optionally can also have the splitter trim whitespace for you -splitter = TextSplitter.from_callback(lambda text: len(text)) +splitter = TextSplitter.from_callback(lambda text: len(text), 1000) -# Maximum number of tokens in a chunk. Will fill up the -# chunk until it is somewhere in this range. -chunks = splitter.chunks("your document text", chunk_capacity=(200,1000)) +chunks = splitter.chunks("your document text") ``` Args: @@ -429,11 +427,13 @@ impl PyTextSplitter { 4. [Unicode Sentence Boundaries](https://www.unicode.org/reports/tr29/#Sentence_Boundaries) 5. Ascending sequence length of newlines. (Newline is `\r\n`, `\n`, or `\r`) Each unique length of consecutive newline sequences is treated as its own semantic level. 
So a sequence of 2 newlines is a higher level than a sequence of 1 newline, and so on. + Splitting doesn't occur below the character level, otherwise you could get partial bytes of a char, which may not be a valid unicode str. + Args: text (str): Text to split. Returns: - A list of strings, one for each chunk. If `trim_chunks` was specified in the text + A list of strings, one for each chunk. If `trim` was specified in the text splitter, then each chunk will already be trimmed as well. */ fn chunks<'text, 'splitter: 'text>(&'splitter self, text: &'text str) -> Vec<&'text str> { @@ -451,7 +451,7 @@ impl PyTextSplitter { Returns: A list of tuples, one for each chunk. The first item will be the character offset relative to the original text. The second item is the chunk itself. - If `trim_chunks` was specified in the text splitter, then each chunk will already be + If `trim` was specified in the text splitter, then each chunk will already be trimmed as well. */ fn chunk_indices<'text, 'splitter: 'text>( @@ -468,10 +468,10 @@ impl PyTextSplitter { #[allow(clippy::large_enum_variant)] enum MarkdownSplitterOptions { - Characters(MarkdownSplitter), - CustomCallback(MarkdownSplitter), - Huggingface(MarkdownSplitter), - Tiktoken(MarkdownSplitter), + Characters(MarkdownSplitter), + CustomCallback(MarkdownSplitter), + Huggingface(MarkdownSplitter), + Tiktoken(MarkdownSplitter), } impl MarkdownSplitterOptions { @@ -479,13 +479,12 @@ impl MarkdownSplitterOptions { fn chunks<'splitter, 'text: 'splitter>( &'splitter self, text: &'text str, - chunk_capacity: PyChunkCapacity, ) -> impl Iterator + 'splitter { match self { - Self::Characters(splitter) => splitter.chunks(text, chunk_capacity), - Self::CustomCallback(splitter) => splitter.chunks(text, chunk_capacity), - Self::Huggingface(splitter) => splitter.chunks(text, chunk_capacity), - Self::Tiktoken(splitter) => splitter.chunks(text, chunk_capacity), + Self::Characters(splitter) => splitter.chunks(text), + 
Self::CustomCallback(splitter) => splitter.chunks(text), + Self::Huggingface(splitter) => splitter.chunks(text), + Self::Tiktoken(splitter) => splitter.chunks(text), } } @@ -493,13 +492,12 @@ impl MarkdownSplitterOptions { fn chunk_indices<'splitter, 'text: 'splitter>( &'splitter self, text: &'text str, - chunk_capacity: PyChunkCapacity, ) -> impl Iterator + 'splitter { match self { - Self::Characters(splitter) => splitter.chunk_indices(text, chunk_capacity), - Self::CustomCallback(splitter) => splitter.chunk_indices(text, chunk_capacity), - Self::Huggingface(splitter) => splitter.chunk_indices(text, chunk_capacity), - Self::Tiktoken(splitter) => splitter.chunk_indices(text, chunk_capacity), + Self::Characters(splitter) => splitter.chunk_indices(text), + Self::CustomCallback(splitter) => splitter.chunk_indices(text), + Self::Huggingface(splitter) => splitter.chunk_indices(text), + Self::Tiktoken(splitter) => splitter.chunk_indices(text), } } } @@ -580,12 +578,14 @@ chunks = splitter.chunks("# Header\n\nyour document text", chunk_capacity=(200,1 ``` Args: - trim_chunks (bool, optional): Specify whether chunks should have whitespace trimmed from the + capacity (int | (int, int)): The capacity of characters in each chunk. If a + single int, then chunks will be filled up as much as possible, without going over + that number. If a tuple of two integers is provided, a chunk will be considered + "full" once it is within the two numbers (inclusive range). So it will only fill + up the chunk until the lower range is met. + trim (bool, optional): Specify whether chunks should have whitespace trimmed from the beginning and end or not. If False, joining all chunks will return the original - string. Indentation however will be preserved if the chunk also includes multiple lines. - Extra newlines are always removed, but if the text would include multiple indented list - items, the indentation of the first element will also be preserved. - Defaults to True. + string. 
Defaults to True. */ #[pyclass(frozen, name = "MarkdownSplitter")] struct PyMarkdownSplitter { @@ -595,12 +595,12 @@ struct PyMarkdownSplitter { #[pymethods] impl PyMarkdownSplitter { #[new] - #[pyo3(signature = (trim_chunks=true))] - fn new(trim_chunks: bool) -> Self { + #[pyo3(signature = (capacity, trim=true))] + fn new(capacity: PyChunkCapacity, trim: bool) -> Self { Self { - splitter: MarkdownSplitterOptions::Characters( - MarkdownSplitter::default().with_trim_chunks(trim_chunks), - ), + splitter: MarkdownSplitterOptions::Characters(MarkdownSplitter::new( + ChunkConfig::new(capacity).with_trim(trim), + )), } } @@ -610,21 +610,24 @@ impl PyMarkdownSplitter { Args: tokenizer (Tokenizer): A `tokenizers.Tokenizer` you want to use to count tokens for each chunk. - trim_chunks (bool, optional): Specify whether chunks should have whitespace trimmed from the + capacity (int | (int, int)): The capacity of characters in each chunk. If a + single int, then chunks will be filled up as much as possible, without going over + that number. If a tuple of two integers is provided, a chunk will be considered + "full" once it is within the two numbers (inclusive range). So it will only fill + up the chunk until the lower range is met. + trim (bool, optional): Specify whether chunks should have whitespace trimmed from the beginning and end or not. If False, joining all chunks will return the original - string. Indentation however will be preserved if the chunk also includes multiple lines. - Extra newlines are always removed, but if the text would include multiple indented list - items, the indentation of the first element will also be preserved. - Defaults to True. + string. Defaults to True. 
Returns: The new markdown splitter */ #[staticmethod] - #[pyo3(signature = (tokenizer, trim_chunks=true))] + #[pyo3(signature = (tokenizer, capacity, trim=true))] fn from_huggingface_tokenizer( tokenizer: &Bound<'_, PyAny>, - trim_chunks: bool, + capacity: PyChunkCapacity, + trim: bool, ) -> PyResult { // Get the json out so we can reconstruct the tokenizer on the Rust side let json = tokenizer.call_method0("to_str")?.extract::()?; @@ -632,9 +635,11 @@ impl PyMarkdownSplitter { Tokenizer::from_str(&json).map_err(|e| PyException::new_err(format!("{e}")))?; Ok(Self { - splitter: MarkdownSplitterOptions::Huggingface( - MarkdownSplitter::new(tokenizer).with_trim_chunks(trim_chunks), - ), + splitter: MarkdownSplitterOptions::Huggingface(MarkdownSplitter::new( + ChunkConfig::new(capacity) + .with_sizer(tokenizer) + .with_trim(trim), + )), }) } @@ -644,27 +649,35 @@ impl PyMarkdownSplitter { Args: json (str): A valid JSON string representing a previously serialized Hugging Face Tokenizer - trim_chunks (bool, optional): Specify whether chunks should have whitespace trimmed from the + capacity (int | (int, int)): The capacity of characters in each chunk. If a + single int, then chunks will be filled up as much as possible, without going over + that number. If a tuple of two integers is provided, a chunk will be considered + "full" once it is within the two numbers (inclusive range). So it will only fill + up the chunk until the lower range is met. + trim (bool, optional): Specify whether chunks should have whitespace trimmed from the beginning and end or not. If False, joining all chunks will return the original - string. Indentation however will be preserved if the chunk also includes multiple lines. - Extra newlines are always removed, but if the text would include multiple indented list - items, the indentation of the first element will also be preserved. - Defaults to True. + string. Defaults to True. 
Returns: The new markdown splitter */ #[staticmethod] - #[pyo3(signature = (json, trim_chunks=true))] - fn from_huggingface_tokenizer_str(json: &str, trim_chunks: bool) -> PyResult { + #[pyo3(signature = (json, capacity, trim=true))] + fn from_huggingface_tokenizer_str( + json: &str, + capacity: PyChunkCapacity, + trim: bool, + ) -> PyResult { let tokenizer = json .parse() .map_err(|e| PyException::new_err(format!("{e}")))?; Ok(Self { - splitter: MarkdownSplitterOptions::Huggingface( - MarkdownSplitter::new(tokenizer).with_trim_chunks(trim_chunks), - ), + splitter: MarkdownSplitterOptions::Huggingface(MarkdownSplitter::new( + ChunkConfig::new(capacity) + .with_sizer(tokenizer) + .with_trim(trim), + )), }) } @@ -674,25 +687,33 @@ impl PyMarkdownSplitter { Args: path (str): A path to a local JSON file representing a previously serialized Hugging Face tokenizer. - trim_chunks (bool, optional): Specify whether chunks should have whitespace trimmed from the + capacity (int | (int, int)): The capacity of characters in each chunk. If a + single int, then chunks will be filled up as much as possible, without going over + that number. If a tuple of two integers is provided, a chunk will be considered + "full" once it is within the two numbers (inclusive range). So it will only fill + up the chunk until the lower range is met. + trim (bool, optional): Specify whether chunks should have whitespace trimmed from the beginning and end or not. If False, joining all chunks will return the original - string. Indentation however will be preserved if the chunk also includes multiple lines. - Extra newlines are always removed, but if the text would include multiple indented list - items, the indentation of the first element will also be preserved. - Defaults to True. + string. Defaults to True. 
Returns: The new markdown splitter */ #[staticmethod] - #[pyo3(signature = (path, trim_chunks=true))] - fn from_huggingface_tokenizer_file(path: &str, trim_chunks: bool) -> PyResult { + #[pyo3(signature = (path, capacity, trim=true))] + fn from_huggingface_tokenizer_file( + path: &str, + capacity: PyChunkCapacity, + trim: bool, + ) -> PyResult { let tokenizer = Tokenizer::from_file(path).map_err(|e| PyException::new_err(format!("{e}")))?; Ok(Self { - splitter: MarkdownSplitterOptions::Huggingface( - MarkdownSplitter::new(tokenizer).with_trim_chunks(trim_chunks), - ), + splitter: MarkdownSplitterOptions::Huggingface(MarkdownSplitter::new( + ChunkConfig::new(capacity) + .with_sizer(tokenizer) + .with_trim(trim), + )), }) } @@ -701,26 +722,30 @@ impl PyMarkdownSplitter { Args: model (str): The OpenAI model name you want to retrieve a tokenizer for. - trim_chunks (bool, optional): Specify whether chunks should have whitespace trimmed from the + capacity (int | (int, int)): The capacity of characters in each chunk. If a + single int, then chunks will be filled up as much as possible, without going over + that number. If a tuple of two integers is provided, a chunk will be considered + "full" once it is within the two numbers (inclusive range). So it will only fill + up the chunk until the lower range is met. + trim (bool, optional): Specify whether chunks should have whitespace trimmed from the beginning and end or not. If False, joining all chunks will return the original - string. Indentation however will be preserved if the chunk also includes multiple lines. - Extra newlines are always removed, but if the text would include multiple indented list - items, the indentation of the first element will also be preserved. - Defaults to True. + string. Defaults to True. 
Returns: The new markdown splitter */ #[staticmethod] - #[pyo3(signature = (model, trim_chunks=true))] - fn from_tiktoken_model(model: &str, trim_chunks: bool) -> PyResult { + #[pyo3(signature = (model, capacity, trim=true))] + fn from_tiktoken_model(model: &str, capacity: PyChunkCapacity, trim: bool) -> PyResult { let tokenizer = get_bpe_from_model(model).map_err(|e| PyException::new_err(format!("{e}")))?; Ok(Self { - splitter: MarkdownSplitterOptions::Tiktoken( - MarkdownSplitter::new(tokenizer).with_trim_chunks(trim_chunks), - ), + splitter: MarkdownSplitterOptions::Tiktoken(MarkdownSplitter::new( + ChunkConfig::new(capacity) + .with_sizer(tokenizer) + .with_trim(trim), + )), }) } @@ -730,23 +755,27 @@ impl PyMarkdownSplitter { Args: callback (Callable[[str], int]): A lambda or other function that can be called. It will be provided a piece of text, and it should return an integer value for the size. - trim_chunks (bool, optional): Specify whether chunks should have whitespace trimmed from the + capacity (int | (int, int)): The capacity of characters in each chunk. If a + single int, then chunks will be filled up as much as possible, without going over + that number. If a tuple of two integers is provided, a chunk will be considered + "full" once it is within the two numbers (inclusive range). So it will only fill + up the chunk until the lower range is met. + trim (bool, optional): Specify whether chunks should have whitespace trimmed from the beginning and end or not. If False, joining all chunks will return the original - string. Indentation however will be preserved if the chunk also includes multiple lines. - Extra newlines are always removed, but if the text would include multiple indented list - items, the indentation of the first element will also be preserved. - Defaults to True. + string. Defaults to True. 
Returns: The new markdown splitter */ #[staticmethod] - #[pyo3(signature = (callback, trim_chunks=true))] - fn from_callback(callback: PyObject, trim_chunks: bool) -> Self { + #[pyo3(signature = (callback, capacity, trim=true))] + fn from_callback(callback: PyObject, capacity: PyChunkCapacity, trim: bool) -> Self { Self { - splitter: MarkdownSplitterOptions::CustomCallback( - MarkdownSplitter::new(CustomCallback(callback)).with_trim_chunks(trim_chunks), - ), + splitter: MarkdownSplitterOptions::CustomCallback(MarkdownSplitter::new( + ChunkConfig::new(capacity) + .with_sizer(CustomCallback(callback)) + .with_trim(trim), + )), } } @@ -777,22 +806,13 @@ impl PyMarkdownSplitter { Args: text (str): Text to split. - chunk_capacity (int | (int, int)): The capacity of characters in each chunk. If a - single int, then chunks will be filled up as much as possible, without going over - that number. If a tuple of two integers is provided, a chunk will be considered - "full" once it is within the two numbers (inclusive range). So it will only fill - up the chunk until the lower range is met. Returns: A list of strings, one for each chunk. If `trim_chunks` was specified in the text splitter, then each chunk will already be trimmed as well. */ - fn chunks<'text, 'splitter: 'text>( - &'splitter self, - text: &'text str, - chunk_capacity: PyChunkCapacity, - ) -> Vec<&'text str> { - self.splitter.chunks(text, chunk_capacity).collect() + fn chunks<'text, 'splitter: 'text>(&'splitter self, text: &'text str) -> Vec<&'text str> { + self.splitter.chunks(text).collect() } /** @@ -802,11 +822,6 @@ impl PyMarkdownSplitter { Args: text (str): Text to split. - chunk_capacity (int | (int, int)): The capacity of characters in each chunk. If a - single int, then chunks will be filled up as much as possible, without going over - that number. If a tuple of two integers is provided, a chunk will be considered - "full" once it is within the two numbers (inclusive range). 
So it will only fill - up the chunk until the lower range is met. Returns: A list of tuples, one for each chunk. The first item will be the character offset relative @@ -817,11 +832,10 @@ impl PyMarkdownSplitter { fn chunk_indices<'text, 'splitter: 'text>( &'splitter self, text: &'text str, - chunk_capacity: PyChunkCapacity, ) -> Vec<(usize, &'text str)> { let mut offsets = ByteToCharOffsetTracker::new(text); self.splitter - .chunk_indices(text, chunk_capacity) + .chunk_indices(text) .map(|c| offsets.map_byte_to_char(c)) .collect() } diff --git a/bindings/python/tests/test_integration.py b/bindings/python/tests/test_integration.py index 3724fec3..6f74b924 100644 --- a/bindings/python/tests/test_integration.py +++ b/bindings/python/tests/test_integration.py @@ -97,96 +97,100 @@ def test_custom() -> None: def test_markdown_chunks() -> None: - splitter = MarkdownSplitter(trim_chunks=False) + splitter = MarkdownSplitter(4, trim=False) text = "123\n\n123" - assert splitter.chunks(text, 4) == ["123\n", "\n123"] + assert splitter.chunks(text) == ["123\n", "\n123"] def test_markdown_chunks_range() -> None: - splitter = MarkdownSplitter(trim_chunks=False) + splitter = MarkdownSplitter(capacity=(3, 4), trim=False) text = "123\n\n123" - assert splitter.chunks(text=text, chunk_capacity=(3, 4)) == [ + assert splitter.chunks(text=text) == [ "123\n", "\n123", ] def test_markdown_chunks_trim() -> None: - splitter = MarkdownSplitter() + splitter = MarkdownSplitter(capacity=4) text = "123\n\n123" - assert splitter.chunks(text=text, chunk_capacity=4) == ["123", "123"] + assert splitter.chunks(text=text) == ["123", "123"] def test_markdown_hugging_face() -> None: tokenizer = Tokenizer.from_pretrained("bert-base-uncased") - splitter = MarkdownSplitter.from_huggingface_tokenizer(tokenizer, trim_chunks=False) + splitter = MarkdownSplitter.from_huggingface_tokenizer(tokenizer, 1, trim=False) text = "123\n\n123" - assert splitter.chunks(text, 1) == ["123\n", "\n123"] + assert 
splitter.chunks(text) == ["123\n", "\n123"] def test_markdown_hugging_face_range() -> None: tokenizer = Tokenizer.from_pretrained("bert-base-uncased") - splitter = MarkdownSplitter.from_huggingface_tokenizer(tokenizer, trim_chunks=False) + splitter = MarkdownSplitter.from_huggingface_tokenizer( + tokenizer, capacity=(1, 2), trim=False + ) text = "123\n\n123" - assert splitter.chunks(text=text, chunk_capacity=(1, 2)) == ["123\n", "\n123"] + assert splitter.chunks(text=text) == ["123\n", "\n123"] def test_markdown_hugging_face_trim() -> None: tokenizer = Tokenizer.from_pretrained("bert-base-uncased") - splitter = MarkdownSplitter.from_huggingface_tokenizer(tokenizer) + splitter = MarkdownSplitter.from_huggingface_tokenizer(tokenizer, capacity=1) text = "123\n\n123" - assert splitter.chunks(text=text, chunk_capacity=1) == ["123", "123"] + assert splitter.chunks(text=text) == ["123", "123"] def test_markdown_hugging_face_from_str() -> None: tokenizer = Tokenizer.from_pretrained("bert-base-uncased") - splitter = MarkdownSplitter.from_huggingface_tokenizer_str(tokenizer.to_str()) + splitter = MarkdownSplitter.from_huggingface_tokenizer_str( + tokenizer.to_str(), capacity=1 + ) text = "123\n\n123" - assert splitter.chunks(text=text, chunk_capacity=1) == ["123", "123"] + assert splitter.chunks(text=text) == ["123", "123"] def test_markdown_hugging_face_from_file() -> None: splitter = MarkdownSplitter.from_huggingface_tokenizer_file( - "tests/bert-base-cased.json" + "tests/bert-base-cased.json", capacity=1 ) text = "123\n\n123" - assert splitter.chunks(text=text, chunk_capacity=1) == ["123", "123"] + assert splitter.chunks(text=text) == ["123", "123"] def test_markdown_tiktoken() -> None: splitter = MarkdownSplitter.from_tiktoken_model( - model="gpt-3.5-turbo", trim_chunks=False + model="gpt-3.5-turbo", capacity=2, trim=False ) text = "123\n\n123" - assert splitter.chunks(text, 2) == ["123\n", "\n123"] + assert splitter.chunks(text) == ["123\n", "\n123"] def 
test_markdown_tiktoken_range() -> None: splitter = MarkdownSplitter.from_tiktoken_model( - model="gpt-3.5-turbo", trim_chunks=False + model="gpt-3.5-turbo", capacity=(2, 3), trim=False ) text = "123\n\n123" - assert splitter.chunks(text=text, chunk_capacity=(2, 3)) == [ + assert splitter.chunks(text=text) == [ "123\n", "\n123", ] def test_markdown_tiktoken_trim() -> None: - splitter = MarkdownSplitter.from_tiktoken_model("gpt-3.5-turbo") + splitter = MarkdownSplitter.from_tiktoken_model("gpt-3.5-turbo", 1) text = "123\n\n123" - assert splitter.chunks(text=text, chunk_capacity=1) == ["123", "123"] + assert splitter.chunks(text=text) == ["123", "123"] def test_markdown_tiktoken_model_error() -> None: with pytest.raises(Exception): - MarkdownSplitter.from_tiktoken_model("random-model-name") + MarkdownSplitter.from_tiktoken_model("random-model-name", 1) def test_markdown_custom() -> None: - splitter = MarkdownSplitter.from_callback(lambda x: len(x)) + splitter = MarkdownSplitter.from_callback(lambda x: len(x), capacity=3) text = "123\n\n123" - assert splitter.chunks(text=text, chunk_capacity=3) == ["123", "123"] + assert splitter.chunks(text) == ["123", "123"] def test_char_indices() -> None: @@ -209,9 +213,9 @@ def test_char_indices_with_multibyte_character() -> None: def test_markdown_char_indices() -> None: - splitter = MarkdownSplitter() + splitter = MarkdownSplitter(capacity=4) text = "123\n456\n789" - assert splitter.chunk_indices(text=text, chunk_capacity=4) == [ + assert splitter.chunk_indices(text) == [ (0, "123"), (4, "456"), (8, "789"), @@ -219,10 +223,10 @@ def test_markdown_char_indices() -> None: def test_markdown_char_indices_with_multibyte_character() -> None: - splitter = MarkdownSplitter() + splitter = MarkdownSplitter(4) text = "12ü\n12ü\n12ü" assert len("12ü\n") == 4 - assert splitter.chunk_indices(text=text, chunk_capacity=4) == [ + assert splitter.chunk_indices(text=text) == [ (0, "12ü"), (4, "12ü"), (8, "12ü"), diff --git a/src/chunk_size.rs 
b/src/chunk_size.rs index 156d6543..6a1d0a8d 100644 --- a/src/chunk_size.rs +++ b/src/chunk_size.rs @@ -86,62 +86,6 @@ pub trait ChunkSizer { fn chunk_size(&self, chunk: &str, capacity: &impl ChunkCapacity) -> ChunkSize; } -/// A memoized chunk sizer that caches the size of chunks. -/// Very helpful when the same chunk is being validated multiple times, which -/// happens often, and can be expensive to compute, such as with tokenizers. -#[derive(Debug)] -pub struct MemoizedChunkSizer<'sizer, C, S> -where - C: ChunkCapacity, - S: ChunkSizer, -{ - /// Cache of chunk sizes per byte offset range - cache: AHashMap, ChunkSize>, - /// How big can each chunk be - chunk_capacity: C, - /// The sizer we are wrapping - sizer: &'sizer S, -} - -impl<'sizer, C, S> MemoizedChunkSizer<'sizer, C, S> -where - C: ChunkCapacity, - S: ChunkSizer, -{ - /// Wrap any chunk sizer for memoization - pub fn new(chunk_capacity: C, sizer: &'sizer S) -> Self { - Self { - cache: AHashMap::new(), - chunk_capacity, - sizer, - } - } - - /// Determine the size of a given chunk to use for validation, - /// returning a cached value if it exists, and storing the result if not. - pub fn chunk_size(&mut self, offset: usize, chunk: &str) -> ChunkSize { - *self - .cache - .entry(offset..(offset + chunk.len())) - .or_insert_with(|| self.sizer.chunk_size(chunk, &self.chunk_capacity)) - } - - /// Check if the chunk is within the capacity. Chunk should be trimmed if necessary beforehand. - pub fn check_capacity(&mut self, (offset, chunk): (usize, &str)) -> ChunkSize { - let mut chunk_size = self.chunk_size(offset, chunk); - if let Some(max_chunk_size_offset) = chunk_size.max_chunk_size_offset.as_mut() { - *max_chunk_size_offset += offset; - } - chunk_size - } - - /// Clear the cached values. Once we've moved the cursor, - /// we don't need to keep the old values around. - pub fn clear_cache(&mut self) { - self.cache.clear(); - } -} - /// Describes the largest valid chunk size(s) that can be generated. 
/// /// An `end` size is required, which is the maximum possible chunk size that @@ -358,7 +302,7 @@ where /// Very helpful when the same chunk is being validated multiple times, which /// happens often, and can be expensive to compute, such as with tokenizers. #[derive(Debug)] -pub struct MemoizedChunkSizer2<'sizer, C, S> +pub struct MemoizedChunkSizer<'sizer, C, S> where C: ChunkCapacity, S: ChunkSizer, @@ -371,7 +315,7 @@ where sizer: &'sizer S, } -impl<'sizer, C, S> MemoizedChunkSizer2<'sizer, C, S> +impl<'sizer, C, S> MemoizedChunkSizer<'sizer, C, S> where C: ChunkCapacity, S: ChunkSizer, @@ -552,7 +496,7 @@ mod tests { #[test] fn memoized_sizer_only_calculates_once_per_text() { let sizer = CountingSizer::default(); - let mut memoized_sizer = MemoizedChunkSizer::new(10, &sizer); + let mut memoized_sizer = MemoizedChunkSizer::new(&10, &sizer); let text = "1234567890"; for _ in 0..10 { memoized_sizer.chunk_size(0, text); @@ -564,7 +508,7 @@ mod tests { #[test] fn memoized_sizer_calculates_once_per_different_text() { let sizer = CountingSizer::default(); - let mut memoized_sizer = MemoizedChunkSizer::new(10, &sizer); + let mut memoized_sizer = MemoizedChunkSizer::new(&10, &sizer); let text = "1234567890"; for i in 0..10 { memoized_sizer.chunk_size(0, text.get(0..i).unwrap()); @@ -579,7 +523,7 @@ mod tests { #[test] fn can_clear_cache_on_memoized_sizer() { let sizer = CountingSizer::default(); - let mut memoized_sizer = MemoizedChunkSizer::new(10, &sizer); + let mut memoized_sizer = MemoizedChunkSizer::new(&10, &sizer); let text = "1234567890"; for _ in 0..10 { memoized_sizer.chunk_size(0, text); diff --git a/src/lib.rs b/src/lib.rs index bf0e057e..e50aeee9 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -3,7 +3,7 @@ use std::{cmp::Ordering, ops::Range}; -use chunk_size::{MemoizedChunkSizer, MemoizedChunkSizer2}; +use chunk_size::MemoizedChunkSizer; use itertools::Itertools; mod chunk_size; @@ -155,242 +155,6 @@ where } impl<'sizer, 'text: 'sizer, C, S, L> 
TextChunks<'text, 'sizer, C, S, L> -where - C: ChunkCapacity, - S: ChunkSizer, - L: Copy + Ord + PartialOrd + 'static, - SemanticSplitRanges: SemanticSplit, -{ - /// Generate new [`TextChunks`] iterator for a given text. - /// Starts with an offset of 0 - fn new(chunk_capacity: C, chunk_sizer: &'sizer S, text: &'text str, trim_chunks: bool) -> Self { - Self { - cursor: 0, - chunk_sizer: MemoizedChunkSizer::new(chunk_capacity, chunk_sizer), - next_sections: Vec::new(), - semantic_split: SemanticSplitRanges::::new(text), - text, - trim_chunks, - } - } - - /// If trim chunks is on, trim the str and adjust the offset - fn trim_chunk(&self, offset: usize, chunk: &'text str) -> (usize, &'text str) { - if self.trim_chunks { - self.semantic_split.trim_chunk(offset, chunk) - } else { - (offset, chunk) - } - } - - /// Generate the next chunk, applying trimming settings. - /// Returns final byte offset and str. - /// Will return `None` if given an invalid range. - fn next_chunk(&mut self) -> Option<(usize, &'text str)> { - // Reset caches so we can reuse the memory allocation - self.chunk_sizer.clear_cache(); - self.semantic_split.update_ranges(self.cursor); - self.update_next_sections(); - - let start = self.cursor; - let mut end = self.cursor; - let mut equals_found = false; - let mut low = 0; - let mut high = self.next_sections.len().saturating_sub(1); - let mut successful_index = None; - let mut successful_chunk_size = None; - - while low <= high { - let mid = low + (high - low) / 2; - let (offset, str) = self.next_sections[mid]; - let text_end = offset + str.len(); - let chunk = self.text.get(start..text_end)?; - let chunk_size = self - .chunk_sizer - .check_capacity(self.trim_chunk(start, chunk)); - - match chunk_size.fits() { - Ordering::Less => { - // We got further than the last one, so update end - if text_end > end { - end = text_end; - successful_index = Some(mid); - successful_chunk_size = Some(chunk_size); - } - } - Ordering::Equal => { - // If we found a 
smaller equals use it. Or if this is the first equals we found - if text_end < end || !equals_found { - end = text_end; - successful_index = Some(mid); - successful_chunk_size = Some(chunk_size); - } - equals_found = true; - } - Ordering::Greater => { - // If we're too big on our smallest run, we must return at least one section - if mid == 0 && start == end { - end = text_end; - successful_index = Some(mid); - successful_chunk_size = Some(chunk_size); - } - } - }; - - // Adjust search area - if chunk_size.fits().is_lt() { - low = mid + 1; - } else if mid > 0 { - high = mid - 1; - } else { - // Nothing to adjust - break; - } - } - - // Sometimes with tokenization, we can get a bigger chunk for the same amount of tokens. - if let (Some(successful_index), Some(chunk_size)) = - (successful_index, successful_chunk_size) - { - let mut range = successful_index..self.next_sections.len(); - // We've already checked the successful index - range.next(); - - for index in range { - let (offset, str) = self.next_sections[index]; - let text_end = offset + str.len(); - let chunk = self.text.get(start..text_end)?; - let size = self - .chunk_sizer - .check_capacity(self.trim_chunk(start, chunk)); - if size.size() <= chunk_size.size() { - if text_end > end { - end = text_end; - } - } else { - break; - } - } - } - - self.cursor = end; - - let chunk = self.text.get(start..self.cursor)?; - - // Trim whitespace if user requested it - Some(self.trim_chunk(start, chunk)) - } - - /// Find the ideal next sections, breaking it up until we find the largest chunk. 
- /// Increasing length of chunk until we find biggest size to minimize validation time - /// on huge chunks - fn update_next_sections(&mut self) { - // First thing, clear out the list, but reuse the allocated memory - self.next_sections.clear(); - // Get starting level - let mut levels_in_remaining_text = - self.semantic_split.levels_in_remaining_text(self.cursor); - let mut semantic_level = levels_in_remaining_text - .next() - .expect("Need at least one level to progress"); - // If we aren't at the highest semantic level, stop iterating sections that go beyond the range of the next level. - let mut max_encoded_offset = None; - - let remaining_text = self.text.get(self.cursor..).unwrap(); - - let levels_with_chunks = levels_in_remaining_text - .filter_map(|level| { - self.semantic_split - .semantic_chunks(self.cursor, remaining_text, level) - .next() - .map(|(_, str)| (level, str)) - }) - // We assume that larger levels are also longer. We can skip lower levels if going to a higher level would result in a shorter text - .coalesce(|(a_level, a_str), (b_level, b_str)| { - if a_str.len() >= b_str.len() { - Ok((b_level, b_str)) - } else { - Err(((a_level, a_str), (b_level, b_str))) - } - }); - for (level, str) in levels_with_chunks { - let chunk_size = self - .chunk_sizer - .check_capacity(self.trim_chunk(self.cursor, str)); - // If this no longer fits, we use the level we are at. - if chunk_size.fits().is_gt() { - max_encoded_offset = chunk_size.max_chunk_size_offset(); - break; - } - // Otherwise break up the text with the next level - semantic_level = level; - } - - let sections = self - .semantic_split - .semantic_chunks(self.cursor, remaining_text, semantic_level) - // We don't want to return items at this level that go beyond the next highest semantic level, as that is most - // likely a meaningful breakpoint we want to preserve. We already know that the next highest doesn't fit anyway, - // so we should be safe to break once we reach it. 
- .take_while_inclusive(move |(offset, _)| { - max_encoded_offset.map_or(true, |max| offset <= &max) - }) - .filter(|(_, str)| !str.is_empty()); - - self.next_sections.extend(sections); - } -} - -impl<'sizer, 'text: 'sizer, C, S, L> Iterator for TextChunks<'text, 'sizer, C, S, L> -where - C: ChunkCapacity, - S: ChunkSizer, - L: Copy + Ord + PartialOrd + 'static, - SemanticSplitRanges: SemanticSplit, -{ - type Item = (usize, &'text str); - - fn next(&mut self) -> Option { - loop { - // Make sure we haven't reached the end - if self.cursor >= self.text.len() { - return None; - } - - match self.next_chunk()? { - // Make sure we didn't get an empty chunk. Should only happen in - // cases where we trim. - (_, "") => continue, - c => return Some(c), - } - } - } -} - -/// Returns chunks of text with their byte offsets as an iterator. -#[derive(Debug)] -struct TextChunks2<'text, 'sizer, C, S, L> -where - C: ChunkCapacity, - S: ChunkSizer, - L: Copy + Ord + PartialOrd + 'static, - SemanticSplitRanges: SemanticSplit, -{ - /// How to validate chunk sizes - chunk_sizer: MemoizedChunkSizer2<'sizer, C, S>, - /// Current byte offset in the `text` - cursor: usize, - /// Reusable container for next sections to avoid extra allocations - next_sections: Vec<(usize, &'text str)>, - /// Splitter used for determining semantic levels. 
- semantic_split: SemanticSplitRanges, - /// Original text to iterate over and generate chunks from - text: &'text str, - /// Whether or not chunks should be trimmed - trim_chunks: bool, -} - -impl<'sizer, 'text: 'sizer, C, S, L> TextChunks2<'text, 'sizer, C, S, L> where C: ChunkCapacity, S: ChunkSizer, @@ -402,7 +166,7 @@ where fn new(chunk_config: &'sizer ChunkConfig, text: &'text str) -> Self { Self { cursor: 0, - chunk_sizer: MemoizedChunkSizer2::new(chunk_config.capacity(), chunk_config.sizer()), + chunk_sizer: MemoizedChunkSizer::new(chunk_config.capacity(), chunk_config.sizer()), next_sections: Vec::new(), semantic_split: SemanticSplitRanges::::new(text), text, @@ -577,7 +341,7 @@ where } } -impl<'sizer, 'text: 'sizer, C, S, L> Iterator for TextChunks2<'text, 'sizer, C, S, L> +impl<'sizer, 'text: 'sizer, C, S, L> Iterator for TextChunks<'text, 'sizer, C, S, L> where C: ChunkCapacity, S: ChunkSizer, diff --git a/src/markdown.rs b/src/markdown.rs index eebb244d..e736632c 100644 --- a/src/markdown.rs +++ b/src/markdown.rs @@ -13,7 +13,7 @@ use pulldown_cmark::{Event, Options, Parser, Tag}; use unicode_segmentation::UnicodeSegmentation; use crate::{ - Characters, ChunkCapacity, ChunkSizer, SemanticSplit, SemanticSplitRanges, TextChunks, + ChunkCapacity, ChunkConfig, ChunkSizer, SemanticSplit, SemanticSplitRanges, TextChunks, }; /// Markdown splitter. Recursively splits chunks into the largest @@ -22,65 +22,35 @@ use crate::{ /// given chunk size. #[derive(Debug)] #[allow(clippy::module_name_repetitions)] -pub struct MarkdownSplitter +pub struct MarkdownSplitter where - S: ChunkSizer, + Capacity: ChunkCapacity, + Sizer: ChunkSizer, { /// Method of determining chunk sizes. - chunk_sizer: S, - /// Whether or not all chunks should have whitespace trimmed. - /// If `false`, joining all chunks should return the original string. - /// If `true`, all chunks will have whitespace removed from beginning and end, preserving indentation if necessary. 
- trim_chunks: bool, + chunk_config: ChunkConfig, } -impl Default for MarkdownSplitter { - fn default() -> Self { - Self::new(Characters) - } -} - -impl MarkdownSplitter +impl MarkdownSplitter where - S: ChunkSizer, + Capacity: ChunkCapacity, + Sizer: ChunkSizer, { /// Creates a new [`MarkdownSplitter`]. /// /// ``` - /// use text_splitter::{Characters, MarkdownSplitter}; + /// use text_splitter::MarkdownSplitter; /// - /// // Characters is the default, so you can also do `MarkdownSplitter::default()` - /// let splitter = MarkdownSplitter::new(Characters); + /// // By default, the chunk sizer is based on characters. + /// let splitter = MarkdownSplitter::new(512); /// ``` #[must_use] - pub fn new(chunk_sizer: S) -> Self { + pub fn new(chunk_config: impl Into>) -> Self { Self { - chunk_sizer, - trim_chunks: false, + chunk_config: chunk_config.into(), } } - /// Specify whether chunks should have whitespace trimmed from the - /// beginning and end or not. - /// - /// If `false` (default), joining all chunks should return the original - /// string. - /// If `true`, all chunks will have whitespace removed from beginning and end. - /// Indentation however will be preserved if the chunk also includes multiple lines. - /// Extra newlines are always removed, but if the text would include multiple indented list - /// items, the indentation of the first element will also be preserved. - /// - /// ``` - /// use text_splitter::{Characters, MarkdownSplitter}; - /// - /// let splitter = MarkdownSplitter::default().with_trim_chunks(true); - /// ``` - #[must_use] - pub fn with_trim_chunks(mut self, trim_chunks: bool) -> Self { - self.trim_chunks = trim_chunks; - self - } - /// Generate a list of chunks from a given text. Each chunk will be up to /// the `max_chunk_size`. /// @@ -103,20 +73,19 @@ where /// Markdown is parsed according to the Commonmark spec, along with some optional features such as GitHub Flavored Markdown. 
/// /// ``` - /// use text_splitter::{Characters, MarkdownSplitter}; + /// use text_splitter::MarkdownSplitter; /// - /// let splitter = MarkdownSplitter::default(); + /// let splitter = MarkdownSplitter::new(10); /// let text = "# Header\n\nfrom a\ndocument"; - /// let chunks = splitter.chunks(text, 10).collect::>(); + /// let chunks = splitter.chunks(text).collect::>(); /// - /// assert_eq!(vec!["# Header\n\n", "from a\n", "document"], chunks); + /// assert_eq!(vec!["# Header", "from a", "document"], chunks); /// ``` pub fn chunks<'splitter, 'text: 'splitter>( &'splitter self, text: &'text str, - chunk_capacity: impl ChunkCapacity + 'splitter, ) -> impl Iterator + 'splitter { - self.chunk_indices(text, chunk_capacity).map(|(_, t)| t) + self.chunk_indices(text).map(|(_, t)| t) } /// Returns an iterator over chunks of the text and their byte offsets. @@ -125,24 +94,18 @@ where /// See [`MarkdownSplitter::chunks`] for more information. /// /// ``` - /// use text_splitter::{Characters, MarkdownSplitter}; + /// use text_splitter::MarkdownSplitter; /// - /// let splitter = MarkdownSplitter::default(); + /// let splitter = MarkdownSplitter::new(10); /// let text = "# Header\n\nfrom a\ndocument"; - /// let chunks = splitter.chunk_indices(text, 10).collect::>(); + /// let chunks = splitter.chunk_indices(text).collect::>(); /// - /// assert_eq!(vec![(0, "# Header\n\n"), (10, "from a\n"), (17, "document")], chunks); + /// assert_eq!(vec![(0, "# Header"), (10, "from a"), (17, "document")], chunks); pub fn chunk_indices<'splitter, 'text: 'splitter>( &'splitter self, text: &'text str, - chunk_capacity: impl ChunkCapacity + 'splitter, ) -> impl Iterator + 'splitter { - TextChunks::<_, S, SemanticLevel>::new( - chunk_capacity, - &self.chunk_sizer, - text, - self.trim_chunks, - ) + TextChunks::<_, Sizer, SemanticLevel>::new(&self.chunk_config, text) } } @@ -429,10 +392,12 @@ mod tests { #[test] fn returns_one_chunk_if_text_is_shorter_than_max_chunk_size() { let text = 
Faker.fake::(); - let chunks = - TextChunks::<_, _, SemanticLevel>::new(text.chars().count(), &Characters, &text, false) - .map(|(_, c)| c) - .collect::>(); + let chunks = TextChunks::<_, _, SemanticLevel>::new( + &ChunkConfig::new(text.chars().count()).with_trim(false), + &text, + ) + .map(|(_, c)| c) + .collect::>(); assert_eq!(vec![&text], chunks); } @@ -444,10 +409,12 @@ mod tests { // Round up to one above half so it goes to 2 chunks let max_chunk_size = text.chars().count() / 2 + 1; - let chunks = - TextChunks::<_, _, SemanticLevel>::new(max_chunk_size, &Characters, &text, false) - .map(|(_, c)| c) - .collect::>(); + let chunks = TextChunks::<_, _, SemanticLevel>::new( + &ChunkConfig::new(max_chunk_size).with_trim(false), + &text, + ) + .map(|(_, c)| c) + .collect::>(); assert!(chunks.iter().all(|c| c.chars().count() <= max_chunk_size)); @@ -467,18 +434,20 @@ mod tests { #[test] fn empty_string() { let text = ""; - let chunks = TextChunks::<_, _, SemanticLevel>::new(100, &Characters, text, false) - .map(|(_, c)| c) - .collect::>(); + let chunks = + TextChunks::<_, _, SemanticLevel>::new(&ChunkConfig::new(100).with_trim(false), text) + .map(|(_, c)| c) + .collect::>(); assert!(chunks.is_empty()); } #[test] fn can_handle_unicode_characters() { let text = "éé"; // Char that is more than one byte - let chunks = TextChunks::<_, _, SemanticLevel>::new(1, &Characters, text, false) - .map(|(_, c)| c) - .collect::>(); + let chunks = + TextChunks::<_, _, SemanticLevel>::new(&ChunkConfig::new(1).with_trim(false), text) + .map(|(_, c)| c) + .collect::>(); assert_eq!(vec!["é", "é"], chunks); } @@ -486,9 +455,10 @@ mod tests { fn chunk_by_graphemes() { let text = "a̐éö̲\r\n"; - let chunks = TextChunks::<_, _, SemanticLevel>::new(3, &Characters, text, false) - .map(|(_, g)| g) - .collect::>(); + let chunks = + TextChunks::<_, _, SemanticLevel>::new(&ChunkConfig::new(3).with_trim(false), text) + .map(|(_, g)| g) + .collect::>(); // \r\n is grouped together not separated 
assert_eq!(vec!["a̐é", "ö̲", "\r\n"], chunks); } @@ -498,7 +468,7 @@ mod tests { let text = " a b "; let chunks = - TextChunks::<_, _, SemanticLevel>::new(1, &Characters, text, true).collect::>(); + TextChunks::<_, _, SemanticLevel>::new(&ChunkConfig::new(1), text).collect::>(); assert_eq!(vec![(1, "a"), (3, "b")], chunks); } @@ -506,9 +476,10 @@ mod tests { fn graphemes_fallback_to_chars() { let text = "a̐éö̲\r\n"; - let chunks = TextChunks::<_, _, SemanticLevel>::new(1, &Characters, text, false) - .map(|(_, g)| g) - .collect::>(); + let chunks = + TextChunks::<_, _, SemanticLevel>::new(&ChunkConfig::new(1).with_trim(false), text) + .map(|(_, g)| g) + .collect::>(); assert_eq!( vec!["a", "\u{310}", "é", "ö", "\u{332}", "\r", "\n"], chunks @@ -520,7 +491,7 @@ mod tests { let text = "\r\na̐éö̲\r\n"; let chunks = - TextChunks::<_, _, SemanticLevel>::new(3, &Characters, text, true).collect::>(); + TextChunks::<_, _, SemanticLevel>::new(&ChunkConfig::new(3), text).collect::>(); assert_eq!(vec![(2, "a̐é"), (7, "ö̲")], chunks); } @@ -528,9 +499,10 @@ mod tests { fn chunk_by_words() { let text = "The quick brown fox can jump 32.3 feet, right?"; - let chunks = TextChunks::<_, _, SemanticLevel>::new(10, &Characters, text, false) - .map(|(_, w)| w) - .collect::>(); + let chunks = + TextChunks::<_, _, SemanticLevel>::new(&ChunkConfig::new(10).with_trim(false), text) + .map(|(_, w)| w) + .collect::>(); assert_eq!( vec![ "The quick ", @@ -546,9 +518,10 @@ mod tests { #[test] fn words_fallback_to_graphemes() { let text = "Thé quick\r\n"; - let chunks = TextChunks::<_, _, SemanticLevel>::new(2, &Characters, text, false) - .map(|(_, w)| w) - .collect::>(); + let chunks = + TextChunks::<_, _, SemanticLevel>::new(&ChunkConfig::new(2).with_trim(false), text) + .map(|(_, w)| w) + .collect::>(); assert_eq!(vec!["Th", "é ", "qu", "ic", "k", "\r\n"], chunks); } @@ -556,7 +529,7 @@ mod tests { fn trim_word_indices() { let text = "Some text from a document"; let chunks = - TextChunks::<_, 
_, SemanticLevel>::new(10, &Characters, text, true).collect::>(); + TextChunks::<_, _, SemanticLevel>::new(&ChunkConfig::new(10), text).collect::>(); assert_eq!( vec![(0, "Some text"), (10, "from a"), (17, "document")], chunks @@ -566,18 +539,20 @@ mod tests { #[test] fn chunk_by_sentences() { let text = "Mr. Fox jumped. The dog was too lazy."; - let chunks = TextChunks::<_, _, SemanticLevel>::new(21, &Characters, text, false) - .map(|(_, s)| s) - .collect::>(); + let chunks = + TextChunks::<_, _, SemanticLevel>::new(&ChunkConfig::new(21).with_trim(false), text) + .map(|(_, s)| s) + .collect::>(); assert_eq!(vec!["Mr. Fox jumped. ", "The dog was too lazy."], chunks); } #[test] fn sentences_falls_back_to_words() { let text = "Mr. Fox jumped. The dog was too lazy."; - let chunks = TextChunks::<_, _, SemanticLevel>::new(16, &Characters, text, false) - .map(|(_, s)| s) - .collect::>(); + let chunks = + TextChunks::<_, _, SemanticLevel>::new(&ChunkConfig::new(16).with_trim(false), text) + .map(|(_, s)| s) + .collect::>(); assert_eq!( vec!["Mr. Fox jumped. ", "The dog was too ", "lazy."], chunks @@ -588,7 +563,7 @@ mod tests { fn trim_sentence_indices() { let text = "Some text. From a document."; let chunks = - TextChunks::<_, _, SemanticLevel>::new(10, &Characters, text, true).collect::>(); + TextChunks::<_, _, SemanticLevel>::new(&ChunkConfig::new(10), text).collect::>(); assert_eq!( vec![(0, "Some text."), (11, "From a"), (18, "document.")], chunks diff --git a/src/text.rs b/src/text.rs index 1709e0ea..9050ed46 100644 --- a/src/text.rs +++ b/src/text.rs @@ -13,7 +13,7 @@ use regex::Regex; use unicode_segmentation::UnicodeSegmentation; use crate::{ - ChunkCapacity, ChunkConfig, ChunkSizer, SemanticSplit, SemanticSplitRanges, TextChunks2, + ChunkCapacity, ChunkConfig, ChunkSizer, SemanticSplit, SemanticSplitRanges, TextChunks, }; /// Default plain-text splitter. 
Recursively splits chunks into the largest @@ -105,7 +105,7 @@ where &'splitter self, text: &'text str, ) -> impl Iterator + 'splitter { - TextChunks2::<_, Sizer, SemanticLevel>::new(&self.chunk_config, text) + TextChunks::<_, Sizer, SemanticLevel>::new(&self.chunk_config, text) } } @@ -254,14 +254,14 @@ mod tests { use fake::{Fake, Faker}; - use crate::{ChunkSize, TextChunks2}; + use crate::{ChunkSize, TextChunks}; use super::*; #[test] fn returns_one_chunk_if_text_is_shorter_than_max_chunk_size() { let text = Faker.fake::(); - let chunks = TextChunks2::<_, _, SemanticLevel>::new( + let chunks = TextChunks::<_, _, SemanticLevel>::new( &ChunkConfig::new(text.chars().count()).with_trim(false), &text, ) @@ -278,7 +278,7 @@ mod tests { // Round up to one above half so it goes to 2 chunks let max_chunk_size = text.chars().count() / 2 + 1; - let chunks = TextChunks2::<_, _, SemanticLevel>::new( + let chunks = TextChunks::<_, _, SemanticLevel>::new( &ChunkConfig::new(max_chunk_size).with_trim(false), &text, ) @@ -304,7 +304,7 @@ mod tests { fn empty_string() { let text = ""; let chunks = - TextChunks2::<_, _, SemanticLevel>::new(&ChunkConfig::new(100).with_trim(false), text) + TextChunks::<_, _, SemanticLevel>::new(&ChunkConfig::new(100).with_trim(false), text) .map(|(_, c)| c) .collect::>(); assert!(chunks.is_empty()); @@ -314,7 +314,7 @@ mod tests { fn can_handle_unicode_characters() { let text = "éé"; // Char that is more than one byte let chunks = - TextChunks2::<_, _, SemanticLevel>::new(&ChunkConfig::new(1).with_trim(false), text) + TextChunks::<_, _, SemanticLevel>::new(&ChunkConfig::new(1).with_trim(false), text) .map(|(_, c)| c) .collect::>(); assert_eq!(vec!["é", "é"], chunks); @@ -335,7 +335,7 @@ mod tests { #[test] fn custom_len_function() { let text = "éé"; // Char that is two bytes each - let chunks = TextChunks2::<_, _, SemanticLevel>::new( + let chunks = TextChunks::<_, _, SemanticLevel>::new( &ChunkConfig::new(2).with_sizer(Str).with_trim(false), text, ) 
@@ -347,7 +347,7 @@ mod tests { #[test] fn handles_char_bigger_than_len() { let text = "éé"; // Char that is two bytes each - let chunks = TextChunks2::<_, _, SemanticLevel>::new( + let chunks = TextChunks::<_, _, SemanticLevel>::new( &ChunkConfig::new(1).with_sizer(Str).with_trim(false), text, ) @@ -362,7 +362,7 @@ mod tests { let text = "a̐éö̲\r\n"; let chunks = - TextChunks2::<_, _, SemanticLevel>::new(&ChunkConfig::new(3).with_trim(false), text) + TextChunks::<_, _, SemanticLevel>::new(&ChunkConfig::new(3).with_trim(false), text) .map(|(_, g)| g) .collect::>(); // \r\n is grouped together not separated @@ -374,7 +374,7 @@ mod tests { let text = " a b "; let chunks = - TextChunks2::<_, _, SemanticLevel>::new(&ChunkConfig::new(1), text).collect::>(); + TextChunks::<_, _, SemanticLevel>::new(&ChunkConfig::new(1), text).collect::>(); assert_eq!(vec![(1, "a"), (3, "b")], chunks); } @@ -383,7 +383,7 @@ mod tests { let text = "a̐éö̲\r\n"; let chunks = - TextChunks2::<_, _, SemanticLevel>::new(&ChunkConfig::new(1).with_trim(false), text) + TextChunks::<_, _, SemanticLevel>::new(&ChunkConfig::new(1).with_trim(false), text) .map(|(_, g)| g) .collect::>(); assert_eq!( @@ -397,7 +397,7 @@ mod tests { let text = "\r\na̐éö̲\r\n"; let chunks = - TextChunks2::<_, _, SemanticLevel>::new(&ChunkConfig::new(3), text).collect::>(); + TextChunks::<_, _, SemanticLevel>::new(&ChunkConfig::new(3), text).collect::>(); assert_eq!(vec![(2, "a̐é"), (7, "ö̲")], chunks); } @@ -406,7 +406,7 @@ mod tests { let text = "The quick (\"brown\") fox can't jump 32.3 feet, right?"; let chunks = - TextChunks2::<_, _, SemanticLevel>::new(&ChunkConfig::new(10).with_trim(false), text) + TextChunks::<_, _, SemanticLevel>::new(&ChunkConfig::new(10).with_trim(false), text) .map(|(_, w)| w) .collect::>(); assert_eq!( @@ -426,7 +426,7 @@ mod tests { fn words_fallback_to_graphemes() { let text = "Thé quick\r\n"; let chunks = - TextChunks2::<_, _, SemanticLevel>::new(&ChunkConfig::new(2).with_trim(false), text) 
+ TextChunks::<_, _, SemanticLevel>::new(&ChunkConfig::new(2).with_trim(false), text) .map(|(_, w)| w) .collect::>(); assert_eq!(vec!["Th", "é ", "qu", "ic", "k", "\r\n"], chunks); @@ -435,8 +435,8 @@ mod tests { #[test] fn trim_word_indices() { let text = "Some text from a document"; - let chunks = TextChunks2::<_, _, SemanticLevel>::new(&ChunkConfig::new(10), text) - .collect::>(); + let chunks = + TextChunks::<_, _, SemanticLevel>::new(&ChunkConfig::new(10), text).collect::>(); assert_eq!( vec![(0, "Some text"), (10, "from a"), (17, "document")], chunks @@ -447,7 +447,7 @@ mod tests { fn chunk_by_sentences() { let text = "Mr. Fox jumped. [...] The dog was too lazy."; let chunks = - TextChunks2::<_, _, SemanticLevel>::new(&ChunkConfig::new(21).with_trim(false), text) + TextChunks::<_, _, SemanticLevel>::new(&ChunkConfig::new(21).with_trim(false), text) .map(|(_, s)| s) .collect::>(); assert_eq!( @@ -460,7 +460,7 @@ mod tests { fn sentences_falls_back_to_words() { let text = "Mr. Fox jumped. [...] The dog was too lazy."; let chunks = - TextChunks2::<_, _, SemanticLevel>::new(&ChunkConfig::new(16).with_trim(false), text) + TextChunks::<_, _, SemanticLevel>::new(&ChunkConfig::new(16).with_trim(false), text) .map(|(_, s)| s) .collect::>(); assert_eq!( @@ -472,8 +472,8 @@ mod tests { #[test] fn trim_sentence_indices() { let text = "Some text. 
From a document."; - let chunks = TextChunks2::<_, _, SemanticLevel>::new(&ChunkConfig::new(10), text) - .collect::>(); + let chunks = + TextChunks::<_, _, SemanticLevel>::new(&ChunkConfig::new(10), text).collect::>(); assert_eq!( vec![(0, "Some text."), (11, "From a"), (18, "document.")], chunks @@ -483,8 +483,8 @@ mod tests { #[test] fn trim_paragraph_indices() { let text = "Some text\n\nfrom a\ndocument"; - let chunks = TextChunks2::<_, _, SemanticLevel>::new(&ChunkConfig::new(10), text) - .collect::>(); + let chunks = + TextChunks::<_, _, SemanticLevel>::new(&ChunkConfig::new(10), text).collect::>(); assert_eq!( vec![(0, "Some text"), (11, "from a"), (18, "document")], chunks diff --git a/tests/markdown.rs b/tests/markdown.rs index 6e495dcd..6670275a 100644 --- a/tests/markdown.rs +++ b/tests/markdown.rs @@ -4,7 +4,7 @@ use fake::{Fake, Faker}; use itertools::Itertools; use more_asserts::assert_le; #[cfg(feature = "markdown")] -use text_splitter::MarkdownSplitter; +use text_splitter::{ChunkConfig, MarkdownSplitter}; #[cfg(feature = "markdown")] #[test] @@ -13,8 +13,8 @@ fn random_chunk_size() { for _ in 0..10 { let max_characters = Faker.fake(); - let splitter = MarkdownSplitter::default(); - let chunks = splitter.chunks(&text, max_characters).collect::>(); + let splitter = MarkdownSplitter::new(ChunkConfig::new(max_characters).with_trim(false)); + let chunks = splitter.chunks(&text).collect::>(); assert_eq!(chunks.join(""), text); for chunk in chunks { @@ -30,10 +30,8 @@ fn random_chunk_indices_increase() { for _ in 0..10 { let max_characters = Faker.fake::(); - let splitter = MarkdownSplitter::default(); - let indices = splitter - .chunk_indices(&text, max_characters) - .map(|(i, _)| i); + let splitter = MarkdownSplitter::new(ChunkConfig::new(max_characters).with_trim(false)); + let indices = splitter.chunk_indices(&text).map(|(i, _)| i); assert!(indices.tuple_windows().all(|(a, b)| a < b)); } @@ -42,10 +40,10 @@ fn random_chunk_indices_increase() { 
#[cfg(feature = "markdown")] #[test] fn fallsback_to_normal_text_split_if_no_markdown_content() { - let splitter = MarkdownSplitter::default(); + let chunk_config = ChunkConfig::new(10).with_trim(false); + let splitter = MarkdownSplitter::new(chunk_config); let text = "Some text\n\nfrom a\ndocument"; - let chunk_size = 10; - let chunks = splitter.chunks(text, chunk_size).collect::>(); + let chunks = splitter.chunks(text).collect::>(); assert_eq!(["Some text\n", "\nfrom a\n", "document"].to_vec(), chunks); } @@ -53,10 +51,9 @@ fn fallsback_to_normal_text_split_if_no_markdown_content() { #[cfg(feature = "markdown")] #[test] fn split_by_rule() { - let splitter = MarkdownSplitter::default(); + let splitter = MarkdownSplitter::new(ChunkConfig::new(12).with_trim(false)); let text = "Some text\n\n---\n\nwith a rule"; - let chunk_size = 12; - let chunks = splitter.chunks(text, chunk_size).collect::>(); + let chunks = splitter.chunks(text).collect::>(); assert_eq!(["Some text\n\n", "---\n", "\nwith a rule"].to_vec(), chunks); } @@ -64,10 +61,9 @@ fn split_by_rule() { #[cfg(feature = "markdown")] #[test] fn split_by_rule_trim() { - let splitter = MarkdownSplitter::default().with_trim_chunks(true); + let splitter = MarkdownSplitter::new(12); let text = "Some text\n\n---\n\nwith a rule"; - let chunk_size = 12; - let chunks = splitter.chunks(text, chunk_size).collect::>(); + let chunks = splitter.chunks(text).collect::>(); assert_eq!(["Some text", "---", "with a rule"].to_vec(), chunks); } @@ -75,10 +71,9 @@ fn split_by_rule_trim() { #[cfg(feature = "markdown")] #[test] fn split_by_headers() { - let splitter = MarkdownSplitter::default(); + let splitter = MarkdownSplitter::new(ChunkConfig::new(30).with_trim(false)); let text = "# Header 1\n\nSome text\n\n## Header 2\n\nwith headings\n"; - let chunk_size = 30; - let chunks = splitter.chunks(text, chunk_size).collect::>(); + let chunks = splitter.chunks(text).collect::>(); assert_eq!( [ @@ -93,10 +88,9 @@ fn split_by_headers() { 
#[cfg(feature = "markdown")] #[test] fn subheadings_grouped_with_top_header() { - let splitter = MarkdownSplitter::default(); + let splitter = MarkdownSplitter::new(ChunkConfig::new(60).with_trim(false)); let text = "# Header 1\n\nSome text\n\n## Header 2\n\nwith headings\n\n### Subheading\n\nand more text\n"; - let chunk_size = 60; - let chunks = splitter.chunks(text, chunk_size).collect::>(); + let chunks = splitter.chunks(text).collect::>(); assert_eq!( [ @@ -111,10 +105,9 @@ fn subheadings_grouped_with_top_header() { #[cfg(feature = "markdown")] #[test] fn trimming_doesnt_trim_block_level_indentation_if_multiple_items() { - let splitter = MarkdownSplitter::default().with_trim_chunks(true); + let splitter = MarkdownSplitter::new(48); let text = "* Really long list item that is too big to fit\n\n * Some Indented Text\n\n * More Indented Text\n\n"; - let chunk_size = 48; - let chunks = splitter.chunks(text, chunk_size).collect::>(); + let chunks = splitter.chunks(text).collect::>(); assert_eq!( [ @@ -129,10 +122,9 @@ fn trimming_doesnt_trim_block_level_indentation_if_multiple_items() { #[cfg(feature = "markdown")] #[test] fn trimming_does_trim_block_level_indentation_if_only_one_item() { - let splitter = MarkdownSplitter::default().with_trim_chunks(true); + let splitter = MarkdownSplitter::new(30); let text = "1. Really long list item\n\n 1. Some Indented Text\n\n 2. 
More Indented Text\n\n"; - let chunk_size = 30; - let chunks = splitter.chunks(text, chunk_size).collect::>(); + let chunks = splitter.chunks(text).collect::>(); assert_eq!( [ diff --git a/tests/text_splitter_snapshots.rs b/tests/text_splitter_snapshots.rs index 110609bc..79a196c0 100644 --- a/tests/text_splitter_snapshots.rs +++ b/tests/text_splitter_snapshots.rs @@ -261,8 +261,8 @@ fn markdown() { let text = fs::read_to_string(path).unwrap(); for chunk_size in CHUNK_SIZES { - let splitter = MarkdownSplitter::default(); - let chunks = splitter.chunks(&text, chunk_size).collect::>(); + let splitter = MarkdownSplitter::new(ChunkConfig::new(chunk_size).with_trim(false)); + let chunks = splitter.chunks(&text).collect::>(); assert_eq!(chunks.join(""), text); for chunk in &chunks { @@ -280,8 +280,8 @@ fn markdown_trim() { let text = fs::read_to_string(path).unwrap(); for chunk_size in CHUNK_SIZES { - let splitter = MarkdownSplitter::default().with_trim_chunks(true); - let chunks = splitter.chunks(&text, chunk_size).collect::>(); + let splitter = MarkdownSplitter::new(chunk_size); + let chunks = splitter.chunks(&text).collect::>(); for chunk in &chunks { assert!(Characters.chunk_size(chunk, &chunk_size).fits().is_le()); @@ -298,8 +298,12 @@ fn huggingface_markdown() { let text = fs::read_to_string(path).unwrap(); for chunk_size in CHUNK_SIZES { - let splitter = MarkdownSplitter::new(&*HUGGINGFACE_TOKENIZER); - let chunks = splitter.chunks(&text, chunk_size).collect::>(); + let splitter = MarkdownSplitter::new( + ChunkConfig::new(chunk_size) + .with_sizer(&*HUGGINGFACE_TOKENIZER) + .with_trim(false), + ); + let chunks = splitter.chunks(&text).collect::>(); assert_eq!(chunks.join(""), text); for chunk in &chunks { @@ -320,8 +324,10 @@ fn huggingface_markdown_trim() { let text = fs::read_to_string(path).unwrap(); for chunk_size in CHUNK_SIZES { - let splitter = MarkdownSplitter::new(&*HUGGINGFACE_TOKENIZER).with_trim_chunks(true); - let chunks = splitter.chunks(&text, 
chunk_size).collect::>(); + let splitter = MarkdownSplitter::new( + ChunkConfig::new(chunk_size).with_sizer(&*HUGGINGFACE_TOKENIZER), + ); + let chunks = splitter.chunks(&text).collect::>(); for chunk in &chunks { assert!(HUGGINGFACE_TOKENIZER @@ -341,8 +347,12 @@ fn tiktoken_markdown() { let text = fs::read_to_string(path).unwrap(); for chunk_size in CHUNK_SIZES { - let splitter = MarkdownSplitter::new(&*TIKTOKEN_TOKENIZER); - let chunks = splitter.chunks(&text, chunk_size).collect::>(); + let splitter = MarkdownSplitter::new( + ChunkConfig::new(chunk_size) + .with_sizer(&*TIKTOKEN_TOKENIZER) + .with_trim(false), + ); + let chunks = splitter.chunks(&text).collect::>(); assert_eq!(chunks.join(""), text); for chunk in &chunks { @@ -363,8 +373,10 @@ fn tiktoken_markdown_trim() { let text = fs::read_to_string(path).unwrap(); for chunk_size in CHUNK_SIZES { - let splitter = MarkdownSplitter::new(&*TIKTOKEN_TOKENIZER).with_trim_chunks(true); - let chunks = splitter.chunks(&text, chunk_size).collect::>(); + let splitter = MarkdownSplitter::new( + ChunkConfig::new(chunk_size).with_sizer(&*TIKTOKEN_TOKENIZER), + ); + let chunks = splitter.chunks(&text).collect::>(); for chunk in &chunks { assert!(TIKTOKEN_TOKENIZER