Skip to content

Commit

Permalink
IL-239 new chunking for overlap with half-overlap size
Browse files Browse the repository at this point in the history
  • Loading branch information
FelixFehseTNG committed Feb 5, 2024
1 parent fdffb7a commit 847864e
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 9 deletions.
25 changes: 19 additions & 6 deletions src/intelligence_layer/core/chunk.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,20 +97,33 @@ def __init__(
)
)
self.chunk_task = ChunkTask(
client, model, max_tokens_per_chunk - overlap_length_tokens
client, model, overlap_length_tokens // 2
)
self.tokenizer = client.tokenizer(model)
self.max_tokens_per_chunk = max_tokens_per_chunk
self.overlap_length_tokens = overlap_length_tokens

def do_run(self, input: ChunkInput, task_span: TaskSpan) -> ChunkOutput:
chunks = self.chunk_task.run(input, task_span).chunks
id_chunks = self.tokenizer.encode_batch(chunks)

chunk_ids = [id_chunks[0].ids]
for i in range(len(id_chunks) - 1):
chunk_ids.append(
chunk_ids[i][-self.overlap_length_tokens :] + id_chunks[i + 1].ids
)

current_chunk = chunk_ids[0]
last_overlap = [chunk_ids[0]]
for chunk in id_chunks[1:]:
if len(chunk.ids) + len(current_chunk) <= self.max_tokens_per_chunk:
current_chunk.extend(chunk.ids)
else:
current_chunk = sum(last_overlap, []) + chunk.ids
chunk_ids.append(current_chunk)

last_overlap.append(chunk.ids)
total_length = len(sum(last_overlap, []))
# while total_length - len(last_overlap[0]) >= self.overlap_length_tokens:
while total_length > self.overlap_length_tokens:
total_length -= len(last_overlap[0])
last_overlap = last_overlap[1:]

print(chunk_ids)
decoded_chunks = self.tokenizer.decode_batch(chunk_ids)
return ChunkOutput(chunks=decoded_chunks)
11 changes: 8 additions & 3 deletions tests/core/test_chunk.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@ def test_overlapped_chunking(
client: AlephAlphaClientProtocol, some_large_text: str
) -> None:
MODEL = "luminous-base"
OVERLAP = 4
MAX_TOKENS = 10
OVERLAP = 8
MAX_TOKENS = 16

tracer = InMemoryTracer()
task = ChunkOverlapTask(
Expand All @@ -34,17 +34,22 @@ def test_overlapped_chunking(
output_tokenized = tokenizer.encode_batch(output.chunks)
for chunk_index in range(len(output_tokenized) - 1):
first = output_tokenized[chunk_index].tokens
print(first)

assert (
len(first) <= MAX_TOKENS + 2
) # `+2` because re-tokenizing the chunk can add a few extra tokens at the beginning or end of each chunk. This is a hack.
next = output_tokenized[chunk_index + 1].tokens

found = False
for offset in range(OVERLAP):
for offset in range(len(next)-OVERLAP//2):
if first[-OVERLAP // 2 :] != next[offset : offset + OVERLAP // 2]:
continue
found = True
break

if not found:
print("first = ", first)
print("next = ", next)

assert found

0 comments on commit 847864e

Please sign in to comment.