IL-239 chunk overlap task (#453)
* IL-239 added ChunkOverlapTask

* IL-239 added `test_chunk.py`

* IL-239 add `overlap_length_tokens` to `SteerableLongContextSummarize`

* IL-239 new chunking for overlap with half-overlap size

* added a break for the recursive summary when the number of partial summaries doesn't change

---------

Co-authored-by: Johannes Wesch <[email protected]>
FelixFehse and JohannesWesch authored Feb 6, 2024
1 parent 1ee4056 commit 3d01960
Showing 5 changed files with 157 additions and 11 deletions.
58 changes: 58 additions & 0 deletions src/intelligence_layer/core/chunk.py
@@ -66,3 +66,61 @@ def do_run(self, input: ChunkInput, task_span: TaskSpan) -> ChunkOutput:
for t in self._splitter.chunks(input.text, self._max_tokens_per_chunk)
]
return ChunkOutput(chunks=chunks)


class ChunkOverlapTask(Task[ChunkInput, ChunkOutput]):
"""Splits a longer text into smaller text chunks, where every chunk overlaps
with the previous chunk by `overlap_length_tokens` number of tokens.
Provide a text of any length and chunk it into smaller pieces using a
tokenizer that is available within the Aleph Alpha client.
Args:
client: Aleph Alpha client instance for running model related API calls.
model: A valid Aleph Alpha model name.
max_tokens_per_chunk: The maximum number of tokens to fit into one chunk.
overlap_length_tokens: The number of tokens every chunk overlaps with the previous chunk.
"""

def __init__(
self,
client: AlephAlphaClientProtocol,
model: str,
max_tokens_per_chunk: int,
overlap_length_tokens: int,
):
super().__init__()
if overlap_length_tokens >= max_tokens_per_chunk:
raise RuntimeError(
"Cannot choose an overlap ({}) longer than the chunk ({})".format(
overlap_length_tokens, max_tokens_per_chunk
)
)
self.chunk_task = ChunkTask(client, model, overlap_length_tokens // 2)
self.tokenizer = client.tokenizer(model)
self.max_tokens_per_chunk = max_tokens_per_chunk
self.overlap_length_tokens = overlap_length_tokens

def do_run(self, input: ChunkInput, task_span: TaskSpan) -> ChunkOutput:
chunks = self.chunk_task.run(input, task_span).chunks
id_chunks = self.tokenizer.encode_batch(chunks)

chunk_ids = [id_chunks[0].ids]
current_chunk = chunk_ids[0]
last_overlap = [chunk_ids[0]]
for chunk in id_chunks[1:]:
if len(chunk.ids) + len(current_chunk) <= self.max_tokens_per_chunk:
current_chunk.extend(chunk.ids)
else:
current_chunk = sum(last_overlap, []) + chunk.ids
chunk_ids.append(current_chunk)

last_overlap.append(chunk.ids)
total_length = len(sum(last_overlap, []))
while total_length > self.overlap_length_tokens:
total_length -= len(last_overlap[0])
last_overlap = last_overlap[1:]

print(chunk_ids)
decoded_chunks = self.tokenizer.decode_batch(chunk_ids)
return ChunkOutput(chunks=decoded_chunks)
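
The constructor first runs the plain `ChunkTask` with a sub-chunk size of `overlap_length_tokens // 2`; `do_run` then greedily packs those half-overlap sub-chunks into chunks of at most `max_tokens_per_chunk` tokens, seeding each new chunk with the most recent sub-chunks so that consecutive chunks share roughly `overlap_length_tokens` tokens. A minimal sketch of that merge step on plain lists of token ids, without the Aleph Alpha client or tokenizer (the function name and the toy ids are illustrative assumptions, not part of the library):

# A minimal sketch of the half-overlap merge on plain lists of token ids.
from typing import List


def merge_with_overlap(
    sub_chunks: List[List[int]],
    max_tokens_per_chunk: int,
    overlap_length_tokens: int,
) -> List[List[int]]:
    # Each sub-chunk is roughly overlap_length_tokens // 2 tokens long.
    chunks = [list(sub_chunks[0])]
    current = chunks[0]
    last_overlap = [sub_chunks[0]]
    for sub in sub_chunks[1:]:
        if len(current) + len(sub) <= max_tokens_per_chunk:
            # The running chunk still has room: append the sub-chunk in place.
            current.extend(sub)
        else:
            # Start a new chunk seeded with the most recent ~overlap tokens.
            current = [t for part in last_overlap for t in part] + list(sub)
            chunks.append(current)
        # Keep only the trailing sub-chunks that form the next overlap window.
        last_overlap.append(sub)
        while sum(len(part) for part in last_overlap) > overlap_length_tokens:
            last_overlap = last_overlap[1:]
    return chunks


# Toy example: ten sub-chunks of 4 token ids, max chunk 16, overlap 8.
sub_chunks = [list(range(i, i + 4)) for i in range(0, 40, 4)]
for chunk in merge_with_overlap(sub_chunks, max_tokens_per_chunk=16, overlap_length_tokens=8):
    print(chunk)

Working with sub-chunks of half the overlap size keeps the bookkeeping simple: the overlap window is always a whole number of trailing sub-chunks, and prepending it to the next chunk reproduces roughly `overlap_length_tokens` shared tokens.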
@@ -27,20 +27,22 @@ def __init__(
def do_run(
self, input: LongContextSummarizeInput, task_span: TaskSpan
) -> SummarizeOutput:
num_partial_summaries = 0
text = input.text
loop_count = 0
while True:
summarize_output = self.long_context_summarize_task.run(
LongContextSummarizeInput(text=text, language=input.language), task_span
)
if num_partial_summaries == len(summarize_output.partial_summaries):
break
num_partial_summaries = len(summarize_output.partial_summaries)

num_generated_tokens = 0
text = ""
for partial_summary in summarize_output.partial_summaries:
num_generated_tokens += partial_summary.generated_tokens
text += partial_summary.summary + "\n"

loop_count += 1

if len(summarize_output.partial_summaries) == 1:
break

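The hunk above adds the stopping condition described in the commit message: the loop now remembers how many partial summaries the previous pass produced and breaks as soon as another pass yields the same number, since the text is evidently no longer shrinking (the existing single-partial-summary break remains as a second exit). A hedged paraphrase of the loop, using only the names visible in the diff; `summarize_once` is a stand-in for the long-context summarize task, and the real code additionally tracks generated tokens and the input's `max_tokens`:

from typing import Callable, List


def recursive_summarize(summarize_once: Callable[[str], List[str]], text: str) -> str:
    num_partial_summaries = 0
    while True:
        partial_summaries = summarize_once(text)
        # New in this commit: stop when the partial-summary count stops
        # changing, i.e. another pass would not compress the text any further.
        if num_partial_summaries == len(partial_summaries):
            break
        num_partial_summaries = len(partial_summaries)
        text = "\n".join(partial_summaries)
        if len(partial_summaries) == 1:
            break
    return text


# Toy check: a "summarizer" that always returns two partial summaries can only
# terminate via the new equal-count condition.
print(recursive_summarize(lambda _: ["part one", "part two"], "some long text"))
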
@@ -2,6 +2,7 @@

from intelligence_layer.connectors import AlephAlphaClientProtocol
from intelligence_layer.core import ChunkInput, ChunkTask, Task, TaskSpan
from intelligence_layer.core.chunk import ChunkOutput, ChunkOverlapTask
from intelligence_layer.core.detect_language import Language
from intelligence_layer.use_cases.summarize.steerable_single_chunk_summarize import (
SteerableSingleChunkSummarize,
@@ -31,31 +32,38 @@ class SteerableLongContextSummarize(
Args:
client: Aleph Alpha client instance for running model related API calls.
few_shot_configs: A mapping of valid `Language` to `FewShotConfig` for each
supported language.
model: A valid Aleph Alpha model name.
max_generated_tokens: The maximum number of tokens per sub-summary.
max_tokens_per_chunk: The maximum number of tokens per chunk that the long text
is divided into.
allowed_languages: List of languages to which the language detection is limited (ISO 639).
fallback_language: The default language of the output.
model: A valid Aleph Alpha model name.
instruction_configs: Dictionary of the prompts for each language.
"""

def __init__(
self,
client: AlephAlphaClientProtocol,
max_generated_tokens: int,
max_tokens_per_chunk: int,
overlap_length_tokens: int = 0,
model: str = "luminous-base-control",
instruction_configs: Mapping[Language, str] = INSTRUCTION_CONFIGS,
) -> None:
super().__init__()
self._summarize = SteerableSingleChunkSummarize(
client, model, max_generated_tokens, instruction_configs
)
self._chunk_task = ChunkTask(
client, model=model, max_tokens_per_chunk=max_tokens_per_chunk
)
self._chunk_task: Task[ChunkInput, ChunkOutput]
if overlap_length_tokens == 0:
self._chunk_task = ChunkTask(
client, model=model, max_tokens_per_chunk=max_tokens_per_chunk
)
else:
self._chunk_task = ChunkOverlapTask(
client,
model=model,
max_tokens_per_chunk=max_tokens_per_chunk,
overlap_length_tokens=overlap_length_tokens,
)

def do_run(
self, input: LongContextSummarizeInput, task_span: TaskSpan
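With this change, passing a non-zero `overlap_length_tokens` makes the summarizer build its chunks with `ChunkOverlapTask`, while the default of 0 keeps the previous `ChunkTask` behaviour. An illustrative construction only; the token budgets below are placeholder values and a real `Client` needs an Aleph Alpha API token:

from aleph_alpha_client import Client

from intelligence_layer.use_cases.summarize.steerable_long_context_summarize import (
    SteerableLongContextSummarize,
)

client = Client(token="YOUR_AA_TOKEN")  # placeholder token
summarize = SteerableLongContextSummarize(
    client,
    max_generated_tokens=128,
    max_tokens_per_chunk=512,
    overlap_length_tokens=64,  # non-zero, so chunks are built by ChunkOverlapTask
)

Because the new parameter defaults to 0, existing callers keep the old non-overlapping chunking without any code changes.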
58 changes: 58 additions & 0 deletions tests/core/test_chunk.py
@@ -0,0 +1,58 @@
from pytest import fixture

from intelligence_layer.connectors import AlephAlphaClientProtocol
from intelligence_layer.core import InMemoryTracer
from intelligence_layer.core.chunk import ChunkInput, ChunkOverlapTask


@fixture
def some_large_text() -> str:
return """
The Williamsburgh Savings Bank Tower, also known as One Hanson Place, is a skyscraper in the Fort Greene neighborhood of Brooklyn in New York City. Located at the northeast corner of Ashland Place and Hanson Place near Downtown Brooklyn, the tower was designed by Halsey, McCormack & Helmer and constructed from 1927 to 1929 as the new headquarters for the Williamsburgh Savings Bank. At 41 stories and 512 feet (156 m) tall, the Williamsburgh Savings Bank Tower was the tallest building in Brooklyn until 2009.
The Williamsburgh Savings Bank was originally headquartered in Williamsburg, Brooklyn; its officers decided to construct a new skyscraper headquarters near Downtown Brooklyn in the mid-1920s. The bank occupied the lowest floors when the building opened on April 1, 1929, while the remaining stories were rented as offices. By the late 20th century, dentists' offices occupied much of the structure. The New York City Landmarks Preservation Commission designated the tower's exterior as a city landmark in 1977 and designated some of the interior spaces in 1996. Through several mergers, the Williamsburgh Savings Bank became part of HSBC Bank USA, which sold the building in 2004. The building's upper stories were converted to luxury condominium apartments from 2005 to 2007, while the banking hall became an event space.
"""


def test_overlapped_chunking(
client: AlephAlphaClientProtocol, some_large_text: str
) -> None:
MODEL = "luminous-base"
OVERLAP = 8
MAX_TOKENS = 16

tracer = InMemoryTracer()
task = ChunkOverlapTask(
client,
model=MODEL,
max_tokens_per_chunk=MAX_TOKENS,
overlap_length_tokens=OVERLAP,
)
output = task.run(ChunkInput(text=some_large_text), tracer)

tokenizer = client.tokenizer(MODEL)
output_tokenized = tokenizer.encode_batch(output.chunks)
for chunk_index in range(len(output_tokenized) - 1):
first = output_tokenized[chunk_index].tokens
print(first)

assert (
len(first)
<= MAX_TOKENS + 2
# `+2` because re-tokenizing the chunk can add a few extra tokens at
# the beginning or end of each chunk. This is a hack.
)
next = output_tokenized[chunk_index + 1].tokens

found = False
for offset in range(len(next) - OVERLAP // 2):
if first[-OVERLAP // 2 :] != next[offset : offset + OVERLAP // 2]:
continue
found = True
break

if not found:
print("first = ", first)
print("next = ", next)

assert found
20 changes: 20 additions & 0 deletions tests/use_cases/summarize/test_recursive_summarize.py
@@ -4,12 +4,18 @@
from aleph_alpha_client import Client, CompletionRequest, CompletionResponse
from pytest import fixture

from intelligence_layer.connectors.limited_concurrency_client import (
AlephAlphaClientProtocol,
)
from intelligence_layer.core import NoOpTracer
from intelligence_layer.use_cases import (
LongContextHighCompressionSummarize,
LongContextSummarizeInput,
RecursiveSummarize,
)
from intelligence_layer.use_cases.summarize.steerable_long_context_summarize import (
SteerableLongContextSummarize,
)


class RecursiveCountingClient(Client):
@@ -52,6 +58,20 @@ def test_recursive_summarize_stops_when_hitting_max_tokens(
assert "new orleans" in output.summary.lower()


def test_recursive_summarize_stops_when_num_partial_summaries_stays_same(
client: AlephAlphaClientProtocol,
) -> None:
max_tokens = None
slcs = SteerableLongContextSummarize(
client, model="luminous-base", max_generated_tokens=75, max_tokens_per_chunk=145
)
input = LongContextSummarizeInput(text=short_text, max_tokens=max_tokens)
task = RecursiveSummarize(slcs)
output = task.run(input, NoOpTracer())

assert output.generated_tokens > 145


def test_recursive_summarize_stops_after_one_chunk(
recursive_counting_client: RecursiveCountingClient,
) -> None:
