diff --git a/src/intelligence_layer/use_cases/summarize/steerable_long_context_summarize.py b/src/intelligence_layer/use_cases/summarize/steerable_long_context_summarize.py index 8cba09b7f..c34c496ce 100644 --- a/src/intelligence_layer/use_cases/summarize/steerable_long_context_summarize.py +++ b/src/intelligence_layer/use_cases/summarize/steerable_long_context_summarize.py @@ -2,6 +2,7 @@ from intelligence_layer.connectors import AlephAlphaClientProtocol from intelligence_layer.core import ChunkInput, ChunkTask, Task, TaskSpan +from intelligence_layer.core.chunk import ChunkOutput, ChunkOverlapTask from intelligence_layer.core.detect_language import Language from intelligence_layer.use_cases.summarize.steerable_single_chunk_summarize import ( SteerableSingleChunkSummarize, @@ -43,6 +44,7 @@ def __init__( client: AlephAlphaClientProtocol, max_generated_tokens: int, max_tokens_per_chunk: int, + overlap_length_tokens: int = 0, model: str = "luminous-base-control", instruction_configs: Mapping[Language, str] = INSTRUCTION_CONFIGS, ) -> None: @@ -50,9 +52,18 @@ def __init__( self._summarize = SteerableSingleChunkSummarize( client, model, max_generated_tokens, instruction_configs ) - self._chunk_task = ChunkTask( - client, model=model, max_tokens_per_chunk=max_tokens_per_chunk - ) + self._chunk_task: Task[ChunkInput, ChunkOutput] + if overlap_length_tokens == 0: + self._chunk_task = ChunkTask( + client, model=model, max_tokens_per_chunk=max_tokens_per_chunk + ) + else: + self._chunk_task = ChunkOverlapTask( + client, + model=model, + max_tokens_per_chunk=max_tokens_per_chunk, + overlap_length_tokens=overlap_length_tokens, + ) def do_run( self, input: LongContextSummarizeInput, task_span: TaskSpan