From 128e8cdcb4b2fc6942b098a28d447dc7f7b127bc Mon Sep 17 00:00:00 2001
From: Felix Fehse
Date: Mon, 5 Feb 2024 18:20:10 +0100
Subject: [PATCH] IL-239 new chunking for overlap with half-overlap size

---
 src/intelligence_layer/core/chunk.py | 24 ++++++++++++++++++------
 tests/core/test_chunk.py             | 15 +++++++++++----
 2 files changed, 29 insertions(+), 10 deletions(-)

diff --git a/src/intelligence_layer/core/chunk.py b/src/intelligence_layer/core/chunk.py
index 22bf2646e..0bdd6b566 100644
--- a/src/intelligence_layer/core/chunk.py
+++ b/src/intelligence_layer/core/chunk.py
@@ -97,9 +97,10 @@ def __init__(
             )
         )
         self.chunk_task = ChunkTask(
-            client, model, max_tokens_per_chunk - overlap_length_tokens
+            client, model, overlap_length_tokens // 2
         )
         self.tokenizer = client.tokenizer(model)
+        self.max_tokens_per_chunk = max_tokens_per_chunk
         self.overlap_length_tokens = overlap_length_tokens
 
     def do_run(self, input: ChunkInput, task_span: TaskSpan) -> ChunkOutput:
@@ -107,10 +108,21 @@
         id_chunks = self.tokenizer.encode_batch(chunks)
 
         chunk_ids = [id_chunks[0].ids]
-        for i in range(len(id_chunks) - 1):
-            chunk_ids.append(
-                chunk_ids[i][-self.overlap_length_tokens :] + id_chunks[i + 1].ids
-            )
-
+        current_chunk = chunk_ids[0]
+        last_overlap = [chunk_ids[0]]
+        for chunk in id_chunks[1:]:
+            if len(chunk.ids) + len(current_chunk) <= self.max_tokens_per_chunk:
+                current_chunk.extend(chunk.ids)
+            else:
+                current_chunk = sum(last_overlap, []) + chunk.ids
+                chunk_ids.append(current_chunk)
+
+            last_overlap.append(chunk.ids)
+            total_length = len(sum(last_overlap, []))
+            while total_length > self.overlap_length_tokens:
+                total_length -= len(last_overlap[0])
+                last_overlap = last_overlap[1:]
+
+        print(chunk_ids)
         decoded_chunks = self.tokenizer.decode_batch(chunk_ids)
         return ChunkOutput(chunks=decoded_chunks)
diff --git a/tests/core/test_chunk.py b/tests/core/test_chunk.py
index 04f278f24..b2c6ca1bc 100644
--- a/tests/core/test_chunk.py
+++ b/tests/core/test_chunk.py
@@ -18,8 +18,8 @@ def test_overlapped_chunking(
     client: AlephAlphaClientProtocol, some_large_text: str
 ) -> None:
     MODEL = "luminous-base"
-    OVERLAP = 4
-    MAX_TOKENS = 10
+    OVERLAP = 8
+    MAX_TOKENS = 16
 
     tracer = InMemoryTracer()
     task = ChunkOverlapTask(
@@ -34,17 +34,24 @@ def test_overlapped_chunking(
     output_tokenized = tokenizer.encode_batch(output.chunks)
     for chunk_index in range(len(output_tokenized) - 1):
         first = output_tokenized[chunk_index].tokens
+        print(first)
         assert (
             len(first) <= MAX_TOKENS + 2
-        )  # `+2` because re-tokenizing the chunk can add a few extra tokens at the beginning or end of each chunk. This is a hack.
+            # `+2` because re-tokenizing the chunk can add a few extra tokens at
+            # the beginning or end of each chunk. This is a hack.
+        )
 
         next = output_tokenized[chunk_index + 1].tokens
         found = False
-        for offset in range(OVERLAP):
+        for offset in range(len(next)-OVERLAP//2):
             if first[-OVERLAP // 2 :] != next[offset : offset + OVERLAP // 2]:
                 continue
             found = True
             break
+        if not found:
+            print("first = ", first)
+            print("next = ", next)
+
         assert found