From a3a653fd7776ec84503418e585bd0dd4362dcd11 Mon Sep 17 00:00:00 2001 From: "niklas.finken" Date: Wed, 31 Jan 2024 14:48:28 +0100 Subject: [PATCH] refactor LongContextSummarize interface --- src/intelligence_layer/use_cases/__init__.py | 3 -- .../summarize/recursive_summarize.py | 28 ++--------------- .../use_cases/summarize/summarize.py | 3 +- .../summarize/test_recursive_summarize.py | 30 ++++--------------- 4 files changed, 9 insertions(+), 55 deletions(-) diff --git a/src/intelligence_layer/use_cases/__init__.py b/src/intelligence_layer/use_cases/__init__.py index c244132bb..0e4bffe61 100644 --- a/src/intelligence_layer/use_cases/__init__.py +++ b/src/intelligence_layer/use_cases/__init__.py @@ -59,9 +59,6 @@ LongContextMediumCompressionSummarize as LongContextMediumCompressionSummarize, ) from .summarize.recursive_summarize import RecursiveSummarize as RecursiveSummarize -from .summarize.recursive_summarize import ( - RecursiveSummarizeInput as RecursiveSummarizeInput, -) from .summarize.single_chunk_few_shot_summarize import ( SingleChunkFewShotSummarize as SingleChunkFewShotSummarize, ) diff --git a/src/intelligence_layer/use_cases/summarize/recursive_summarize.py b/src/intelligence_layer/use_cases/summarize/recursive_summarize.py index 4ae2d1e3e..80cebabdd 100644 --- a/src/intelligence_layer/use_cases/summarize/recursive_summarize.py +++ b/src/intelligence_layer/use_cases/summarize/recursive_summarize.py @@ -1,8 +1,3 @@ -from typing import Optional - -from pydantic import BaseModel - -from intelligence_layer.core.detect_language import Language from intelligence_layer.core.task import Task from intelligence_layer.core.tracer import TaskSpan from intelligence_layer.use_cases.summarize.summarize import ( @@ -12,23 +7,7 @@ ) -class RecursiveSummarizeInput(BaseModel): - """The Input for a recursive summarize task. - - Attributes: - text: A text of any length. - language: The desired language of the summary. ISO 619 str with language e.g. en, fr, etc. - max_tokens: The max number of tokens to be in the final summary. - max_loops: The max number of times to recursively summarize. - """ - - text: str - language: Language = Language("en") - max_tokens: Optional[int] = None - max_loops: Optional[int] = None - - -class RecursiveSummarize(Task[RecursiveSummarizeInput, SummarizeOutput]): +class RecursiveSummarize(Task[LongContextSummarizeInput, SummarizeOutput]): """Condenses a text recursively by summarizing summaries. Uses any long-context summarize task to go over text recursively and condense it even further. @@ -46,7 +25,7 @@ def __init__( self.long_context_summarize_task = long_context_summarize_task def do_run( - self, input: RecursiveSummarizeInput, task_span: TaskSpan + self, input: LongContextSummarizeInput, task_span: TaskSpan ) -> SummarizeOutput: text = input.text loop_count = 0 @@ -68,9 +47,6 @@ def do_run( elif input.max_tokens and num_generated_tokens < input.max_tokens: break - elif input.max_loops and loop_count <= input.max_loops: - break - return SummarizeOutput( summary=text.strip(), generated_tokens=num_generated_tokens ) diff --git a/src/intelligence_layer/use_cases/summarize/summarize.py b/src/intelligence_layer/use_cases/summarize/summarize.py index 4e6507f78..dcf717ebb 100644 --- a/src/intelligence_layer/use_cases/summarize/summarize.py +++ b/src/intelligence_layer/use_cases/summarize/summarize.py @@ -1,4 +1,4 @@ -from typing import Iterable, Sequence, Union +from typing import Iterable, Optional, Sequence, Union from pydantic import BaseModel @@ -24,6 +24,7 @@ class LongContextSummarizeInput(BaseModel): text: str language: Language = Language("en") + max_tokens: Optional[int] = None class PartialSummary(BaseModel): diff --git a/tests/use_cases/summarize/test_recursive_summarize.py b/tests/use_cases/summarize/test_recursive_summarize.py index 86bed1c3c..b0e95df1e 100644 --- a/tests/use_cases/summarize/test_recursive_summarize.py +++ b/tests/use_cases/summarize/test_recursive_summarize.py @@ -4,13 +4,11 @@ from aleph_alpha_client import Client, CompletionRequest, CompletionResponse from pytest import fixture -from intelligence_layer.core.tracer import NoOpTracer -from intelligence_layer.use_cases.summarize.long_context_high_compression_summarize import ( +from intelligence_layer.core import NoOpTracer +from intelligence_layer.use_cases import ( LongContextHighCompressionSummarize, -) -from intelligence_layer.use_cases.summarize.recursive_summarize import ( + LongContextSummarizeInput, RecursiveSummarize, - RecursiveSummarizeInput, ) @@ -45,7 +43,7 @@ def test_recursive_summarize_stops_when_hitting_max_tokens( long_context_high_compression_summarize: LongContextHighCompressionSummarize, ) -> None: max_tokens = 1000 - input = RecursiveSummarizeInput(text=very_long_text, max_tokens=max_tokens) + input = LongContextSummarizeInput(text=very_long_text, max_tokens=max_tokens) task = RecursiveSummarize(long_context_high_compression_summarize) output = task.run(input, NoOpTracer()) @@ -54,31 +52,13 @@ def test_recursive_summarize_stops_when_hitting_max_tokens( assert "new orleans" in output.summary.lower() -def test_recursive_summarize_stops_when_hitting_max_loops( - very_long_text: str, - recursive_counting_client: RecursiveCountingClient, -) -> None: - long_context_high_compression_summarize = LongContextHighCompressionSummarize( - recursive_counting_client, model="luminous-base" - ) - input = RecursiveSummarizeInput(text=very_long_text, max_loops=1) - task = RecursiveSummarize(long_context_high_compression_summarize) - output = task.run(input, NoOpTracer()) - - assert len(output.summary) < len(very_long_text) - assert ( - recursive_counting_client.recursive_counter == 71 - ) # text is chunked into 71 chunks - assert "new orleans" in output.summary.lower() - - def test_recursive_summarize_stops_after_one_chunk( recursive_counting_client: RecursiveCountingClient, ) -> None: long_context_high_compression_summarize = LongContextHighCompressionSummarize( recursive_counting_client, model="luminous-base" ) - input = RecursiveSummarizeInput(text=short_text) + input = LongContextSummarizeInput(text=short_text) task = RecursiveSummarize(long_context_high_compression_summarize) task.run(input, NoOpTracer())