From c91a2ad8b8c263dfae0ad43bb5600e109c2e7387 Mon Sep 17 00:00:00 2001 From: Luke Alvoeiro Date: Mon, 23 Sep 2024 12:21:40 -0700 Subject: [PATCH 1/7] feat: offset for summarizer moderator --- src/exchange/moderators/summarizer.py | 23 +++++++++++++++++++++-- src/exchange/moderators/truncate.py | 9 ++++++--- tests/test_summarizer.py | 16 +++++++--------- 3 files changed, 34 insertions(+), 14 deletions(-) diff --git a/src/exchange/moderators/summarizer.py b/src/exchange/moderators/summarizer.py index 7e2dd55..b05e31a 100644 --- a/src/exchange/moderators/summarizer.py +++ b/src/exchange/moderators/summarizer.py @@ -3,19 +3,38 @@ from exchange import Message from exchange.checkpoint import CheckpointData from exchange.moderators import ContextTruncate, PassiveModerator +from exchange.moderators.truncate import MAX_TOKENS + +# this offset is used to prevent summarization from happening too frequently +SUMMARIZATION_OFFSET = 30000 class ContextSummarizer(ContextTruncate): + def __init__( + self, + model: str = "gpt-4o-mini", + max_tokens: int = MAX_TOKENS, + summarization_offset: int = SUMMARIZATION_OFFSET, + ) -> None: + super().__init__(model=model, max_tokens=max_tokens) + self.summarization_offset = summarization_offset + def rewrite(self, exchange: Type["exchange.exchange.Exchange"]) -> None: # noqa: F821 """Summarize the context history up to the last few messages in the exchange""" self._update_system_prompt_token_count(exchange) - # TODO: use an offset for summarization + # note: we don't use num_tokens_to_keep (defined below) here, because we only want to trigger + # summarization when we pass the max_tokens threshold. at that point, we will summarize more than + # we need to (using the offset), so we don't need to summarize on every ex.generate(...)
call if exchange.checkpoint_data.total_token_count < self.max_tokens: return - messages_to_summarize = self._get_messages_to_remove(exchange) + assert self.summarization_offset < self.max_tokens + + num_tokens_to_keep = self.max_tokens - self.summarization_offset + + messages_to_summarize = self._get_messages_to_remove(exchange, max_tokens=num_tokens_to_keep) num_messages_to_remove = len(messages_to_summarize) # the llm will throw an error if the last message isn't a user message diff --git a/src/exchange/moderators/truncate.py b/src/exchange/moderators/truncate.py index 41115f6..f0c5059 100644 --- a/src/exchange/moderators/truncate.py +++ b/src/exchange/moderators/truncate.py @@ -62,15 +62,18 @@ def _update_system_prompt_token_count(self, exchange: Exchange) -> None: exchange.checkpoint_data.total_token_count -= last_system_prompt_token_count exchange.checkpoint_data.total_token_count += self.system_prompt_token_count - def _get_messages_to_remove(self, exchange: Exchange) -> List[Message]: + def _get_messages_to_remove(self, exchange: Exchange, max_tokens: Optional[int] = None) -> List[Message]: + if not max_tokens: + max_tokens = self.max_tokens + # this keeps all the messages/checkpoints throwaway_exchange = exchange.replace( moderator=PassiveModerator(), ) - # get the messages that we want to remove + # get the messages that we want to summarize messages_to_remove = [] - while throwaway_exchange.checkpoint_data.total_token_count > self.max_tokens: + while throwaway_exchange.checkpoint_data.total_token_count > max_tokens: _, messages = throwaway_exchange.pop_first_checkpoint() messages_to_remove.extend(messages) diff --git a/tests/test_summarizer.py b/tests/test_summarizer.py index fa72819..ddac0a2 100644 --- a/tests/test_summarizer.py +++ b/tests/test_summarizer.py @@ -49,7 +49,7 @@ def exchange_instance(): @pytest.fixture def summarizer_instance(): - return ContextSummarizer(max_tokens=300) + return ContextSummarizer(max_tokens=300, summarization_offset=100) def test_context_summarizer_rewrite(exchange_instance: Exchange, summarizer_instance: ContextSummarizer): @@ -199,9 +199,7 @@ def conversation_exchange_instance(): provider=AnotherMockProvider(), model="test-model", system="test-system", - moderator=ContextSummarizer(max_tokens=300), - # TODO: make it work with an offset so we don't have to send off requests basically - # at every generate step + moderator=ContextSummarizer(max_tokens=300, summarization_offset=100), ) return ex @@ -215,11 +213,11 @@ def test_summarizer_generic_conversation(conversation_exchange_instance: Exchang if message.text != "Summary message here": i += 2 checkpoints = conversation_exchange_instance.checkpoint_data.checkpoints - assert conversation_exchange_instance.checkpoint_data.total_token_count == 570 - assert len(checkpoints) == 10 - assert len(conversation_exchange_instance.messages) == 10 - assert checkpoints[0].start_index == 20 - assert checkpoints[0].end_index == 20 + assert conversation_exchange_instance.checkpoint_data.total_token_count == 412 + assert len(checkpoints) == 5 + assert len(conversation_exchange_instance.messages) == 5 + assert checkpoints[0].start_index == 25 + assert checkpoints[0].end_index == 25 assert checkpoints[-1].start_index == 29 assert checkpoints[-1].end_index == 29 assert conversation_exchange_instance.checkpoint_data.message_index_offset == 20 From 911cd645498c5fe801b44046f103024f84190025 Mon Sep 17 00:00:00 2001 From: Luke Alvoeiro Date: Mon, 23 Sep 2024 12:41:45 -0700 Subject: [PATCH 2/7] feat: always stay below 
token count --- src/exchange/exchange.py | 7 ++----- tests/test_summarizer.py | 16 ++++++++-------- 2 files changed, 10 insertions(+), 13 deletions(-) diff --git a/src/exchange/exchange.py b/src/exchange/exchange.py index feece4f..008ee7a 100644 --- a/src/exchange/exchange.py +++ b/src/exchange/exchange.py @@ -81,11 +81,8 @@ def generate(self) -> Message: self.add(message) self.add_checkpoints_from_usage(usage) # this has to come after adding the response - # TODO: also call `rewrite` here, as this will make our - # messages *consistently* below the token limit. this currently - # is not the case because we could append a large message after calling - # `rewrite` above. - # self.moderator.rewrite(self) + # also call `rewrite` here, as this keeps our messages consistently below the token limit. + self.moderator.rewrite(self) _token_usage_collector.collect(self.model, usage) return message diff --git a/tests/test_summarizer.py b/tests/test_summarizer.py index ddac0a2..701a390 100644 --- a/tests/test_summarizer.py +++ b/tests/test_summarizer.py @@ -149,7 +149,6 @@ def complete(self, model, system, messages, tools): system_prompt_tokens = 100 input_token_count = system_prompt_tokens - message = self.sequence[self.current_index] if self.summarize_next: text = "Summary message here" self.summarize_next = False @@ -160,6 +159,7 @@ def complete(self, model, system, messages, tools): output_tokens=len(text) * 2, total_tokens=40 + len(text) * 2, ) + message = self.sequence[self.current_index] if len(messages) > 0 and type(messages[0].content[0]) is ToolResult: raise ValueError("ToolResult should not be the first message") @@ -213,13 +213,13 @@ def test_summarizer_generic_conversation(conversation_exchange_instance: Exchang if message.text != "Summary message here": i += 2 checkpoints = conversation_exchange_instance.checkpoint_data.checkpoints - assert conversation_exchange_instance.checkpoint_data.total_token_count == 412 - assert len(checkpoints) == 5 - assert len(conversation_exchange_instance.messages) == 5 - assert checkpoints[0].start_index == 25 - assert checkpoints[0].end_index == 25 + assert conversation_exchange_instance.checkpoint_data.total_token_count == 148 + assert len(checkpoints) == 4 + assert len(conversation_exchange_instance.messages) == 4 + assert checkpoints[0].start_index == 26 + assert checkpoints[0].end_index == 26 assert checkpoints[-1].start_index == 29 assert checkpoints[-1].end_index == 29 - assert conversation_exchange_instance.checkpoint_data.message_index_offset == 20 - assert conversation_exchange_instance.provider.summarized_count == 12 + assert conversation_exchange_instance.checkpoint_data.message_index_offset == 26 + assert conversation_exchange_instance.provider.summarized_count == 10 assert conversation_exchange_instance.moderator.system_prompt_token_count == 100 From 3a4be82966c07222f5ad704ed64249f945827f05 Mon Sep 17 00:00:00 2001 From: Luke Alvoeiro Date: Mon, 23 Sep 2024 12:42:53 -0700 Subject: [PATCH 3/7] fix: increase token limit since we will now consistently stay below it --- src/exchange/moderators/truncate.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/exchange/moderators/truncate.py b/src/exchange/moderators/truncate.py index f0c5059..e80b8d2 100644 --- a/src/exchange/moderators/truncate.py +++ b/src/exchange/moderators/truncate.py @@ -14,7 +14,7 @@ # so once we get to this token size the token count will exceed this # by a little bit.
# TODO: make this configurable for each provider -MAX_TOKENS = 100000 +MAX_TOKENS = 128000 class ContextTruncate(Moderator): From 148136e4a856be4028bc3ac6fe6e6f0d5aae6314 Mon Sep 17 00:00:00 2001 From: Luke Alvoeiro Date: Mon, 23 Sep 2024 12:46:54 -0700 Subject: [PATCH 4/7] feat: enable summarize moderator as the default moderator --- src/exchange/exchange.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/exchange/exchange.py b/src/exchange/exchange.py index 008ee7a..1150d69 100644 --- a/src/exchange/exchange.py +++ b/src/exchange/exchange.py @@ -10,7 +10,7 @@ from exchange.content import Text, ToolResult, ToolUse from exchange.message import Message from exchange.moderators import Moderator -from exchange.moderators.truncate import ContextTruncate +from exchange.moderators.summarizer import ContextSummarizer from exchange.providers import Provider, Usage from exchange.tool import Tool from exchange.token_usage_collector import _token_usage_collector @@ -40,7 +40,7 @@ class Exchange: provider: Provider model: str system: str - moderator: Moderator = field(default=ContextTruncate()) + moderator: Moderator = field(default=ContextSummarizer()) tools: Tuple[Tool] = field(factory=tuple, converter=tuple) messages: List[Message] = field(factory=list) checkpoint_data: CheckpointData = field(factory=CheckpointData) From 42933d3e6aee691a4d188a5fa0d7435e76f9705f Mon Sep 17 00:00:00 2001 From: Luke Alvoeiro Date: Mon, 23 Sep 2024 12:52:23 -0700 Subject: [PATCH 5/7] chore: update check in `test_truncate` to ensure no regressions --- tests/test_truncate.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_truncate.py b/tests/test_truncate.py index 3875303..24ae727 100644 --- a/tests/test_truncate.py +++ b/tests/test_truncate.py @@ -128,5 +128,5 @@ def test_truncate_on_generic_conversation(conversation_exchange_instance: Exchan if message.text != "Summary message here": i += 2 # ensure the total token count is not anything exhorbitant - assert conversation_exchange_instance.checkpoint_data.total_token_count < 700 + assert conversation_exchange_instance.checkpoint_data.total_token_count < 500 assert conversation_exchange_instance.moderator.system_prompt_token_count == 100 From 667296c79cdbc8dba271181bae53617992e57bae Mon Sep 17 00:00:00 2001 From: Adrian Cole Date: Wed, 25 Sep 2024 09:18:35 +1000 Subject: [PATCH 6/7] test: use default model in ollama and reduce code redundancy Signed-off-by: Adrian Cole --- .github/workflows/ci.yaml | 12 +++----- ...pletion.yaml => test_ollama_complete.yaml} | 0 ...pletion.yaml => test_openai_complete.yaml} | 0 tests/providers/openai/conftest.py | 17 ++++++++++- tests/providers/openai/test_ollama.py | 24 ++++----------- tests/providers/openai/test_openai.py | 30 ++++--------------- 6 files changed, 32 insertions(+), 51 deletions(-) rename tests/providers/openai/cassettes/{test_ollama_completion.yaml => test_ollama_complete.yaml} (100%) rename tests/providers/openai/cassettes/{test_openai_completion.yaml => test_openai_complete.yaml} (100%) diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 6999219..ed9bad1 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -43,11 +43,7 @@ jobs: strategy: matrix: python-version: - # Only test the lastest python version. - - "3.12" - ollama-model: - # For quicker CI, use a smaller, tool-capable model than the default. - - "qwen2.5:0.5b" + - "3.12" # Only test the latest python version. 
steps: - uses: actions/checkout@v4 @@ -75,11 +71,11 @@ jobs: # Tests use OpenAI which does not have a mechanism to pull models. Run a # simple prompt to (pull and) test the model first. - name: Test Ollama model - run: ollama run $OLLAMA_MODEL hello || cat ollama.log + run: | + OLLAMA_MODEL=$(uv run python -c "from src.exchange.providers.ollama import OLLAMA_MODEL; print(OLLAMA_MODEL)") + ollama run $OLLAMA_MODEL hello || cat ollama.log env: OLLAMA_MODEL: ${{ matrix.ollama-model }} - name: Run Ollama tests run: uv run pytest tests -m integration -k ollama - env: - OLLAMA_MODEL: ${{ matrix.ollama-model }} diff --git a/tests/providers/openai/cassettes/test_ollama_completion.yaml b/tests/providers/openai/cassettes/test_ollama_complete.yaml similarity index 100% rename from tests/providers/openai/cassettes/test_ollama_completion.yaml rename to tests/providers/openai/cassettes/test_ollama_complete.yaml diff --git a/tests/providers/openai/cassettes/test_openai_completion.yaml b/tests/providers/openai/cassettes/test_openai_complete.yaml similarity index 100% rename from tests/providers/openai/cassettes/test_openai_completion.yaml rename to tests/providers/openai/cassettes/test_openai_complete.yaml diff --git a/tests/providers/openai/conftest.py b/tests/providers/openai/conftest.py index e282c87..c752e73 100644 --- a/tests/providers/openai/conftest.py +++ b/tests/providers/openai/conftest.py @@ -1,11 +1,19 @@ import os +from typing import Type, Tuple + import pytest -OPENAI_MODEL = "gpt-4o-mini" +from exchange import Message +from exchange.providers import Usage, Provider +from exchange.providers.ollama import OLLAMA_MODEL + OPENAI_API_KEY = "test_openai_api_key" OPENAI_ORG_ID = "test_openai_org_key" OPENAI_PROJECT_ID = "test_openai_project_id" +OLLAMA_MODEL = os.getenv("OLLAMA_MODEL", OLLAMA_MODEL) +OPENAI_MODEL = os.getenv("OPENAI_MODEL", "gpt-4o-mini") + @pytest.fixture def default_openai_api_key(monkeypatch): @@ -50,3 +58,10 @@ def scrub_response_headers(response): response["headers"]["openai-organization"] = OPENAI_ORG_ID response["headers"]["Set-Cookie"] = "test_set_cookie" return response + + +def complete(provider_cls: Type[Provider], model: str) -> Tuple[Message, Usage]: + provider = provider_cls.from_env() + system = "You are a helpful assistant." + messages = [Message.user("Hello")] + return provider.complete(model=model, system=system, messages=messages, tools=None) diff --git a/tests/providers/openai/test_ollama.py b/tests/providers/openai/test_ollama.py index 0aa4646..d8a02ad 100644 --- a/tests/providers/openai/test_ollama.py +++ b/tests/providers/openai/test_ollama.py @@ -1,33 +1,21 @@ -from typing import Tuple - -import os import pytest from exchange import Text -from exchange.message import Message -from exchange.providers.base import Usage -from exchange.providers.ollama import OllamaProvider, OLLAMA_MODEL +from exchange.providers.ollama import OllamaProvider +from .conftest import complete, OLLAMA_MODEL @pytest.mark.vcr() -def test_ollama_completion(default_openai_api_key): - reply_message, reply_usage = ollama_complete() +def test_ollama_complete(default_openai_api_key): + reply_message, reply_usage = complete(OllamaProvider, OLLAMA_MODEL) assert reply_message.content == [Text(text="Hello! I'm here to help. How can I assist you today? Let's chat. 
😊")] assert reply_usage.total_tokens == 33 @pytest.mark.integration -def test_ollama_completion_integration(): - reply = ollama_complete() +def test_ollama_complete_integration(): + reply = complete(OllamaProvider, OLLAMA_MODEL) assert reply[0].content is not None print("Completion content from OpenAI:", reply[0].content) - - -def ollama_complete() -> Tuple[Message, Usage]: - provider = OllamaProvider.from_env() - model = os.getenv("OLLAMA_MODEL", OLLAMA_MODEL) - system = "You are a helpful assistant." - messages = [Message.user("Hello")] - return provider.complete(model=model, system=system, messages=messages, tools=None) diff --git a/tests/providers/openai/test_openai.py b/tests/providers/openai/test_openai.py index 354d1c5..5f72981 100644 --- a/tests/providers/openai/test_openai.py +++ b/tests/providers/openai/test_openai.py @@ -1,39 +1,21 @@ -from typing import Tuple - -import os import pytest from exchange import Text -from exchange.message import Message -from exchange.providers.base import Usage from exchange.providers.openai import OpenAiProvider -from .conftest import OPENAI_MODEL, OPENAI_API_KEY +from .conftest import complete, OPENAI_MODEL @pytest.mark.vcr() -def test_openai_completion(monkeypatch): - # When running VCR tests the first time, it needs OPENAI_API_KEY to call - # the real service. Afterward, it is not needed as VCR mocks the service. - if "OPENAI_API_KEY" not in os.environ: - monkeypatch.setenv("OPENAI_API_KEY", OPENAI_API_KEY) - - reply_message, reply_usage = openai_complete() +def test_openai_complete(default_openai_api_key): + reply_message, reply_usage = complete(OpenAiProvider, OPENAI_MODEL) assert reply_message.content == [Text(text="Hello! How can I assist you today?")] assert reply_usage.total_tokens == 27 @pytest.mark.integration -def test_openai_completion_integration(): - reply = openai_complete() +def test_openai_complete_integration(): + reply = complete(OpenAiProvider, OPENAI_MODEL) assert reply[0].content is not None - print("Completion content from OpenAI:", reply[0].content) - - -def openai_complete() -> Tuple[Message, Usage]: - provider = OpenAiProvider.from_env() - model = OPENAI_MODEL - system = "You are a helpful assistant." 
- messages = [Message.user("Hello")] - return provider.complete(model=model, system=system, messages=messages, tools=None) + print("Complete content from OpenAI:", reply[0].content) From 0e3632c6d7415f4e42fafa6ebae789e20d34b383 Mon Sep 17 00:00:00 2001 From: Mic Neale Date: Wed, 25 Sep 2024 10:09:06 +1000 Subject: [PATCH 7/7] change default back --- src/exchange/exchange.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/exchange/exchange.py b/src/exchange/exchange.py index 1150d69..008ee7a 100644 --- a/src/exchange/exchange.py +++ b/src/exchange/exchange.py @@ -10,7 +10,7 @@ from exchange.content import Text, ToolResult, ToolUse from exchange.message import Message from exchange.moderators import Moderator -from exchange.moderators.summarizer import ContextSummarizer +from exchange.moderators.truncate import ContextTruncate from exchange.providers import Provider, Usage from exchange.tool import Tool from exchange.token_usage_collector import _token_usage_collector @@ -40,7 +40,7 @@ class Exchange: provider: Provider model: str system: str - moderator: Moderator = field(default=ContextSummarizer()) + moderator: Moderator = field(default=ContextTruncate()) tools: Tuple[Tool] = field(factory=tuple, converter=tuple) messages: List[Message] = field(factory=list) checkpoint_data: CheckpointData = field(factory=CheckpointData)
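
Taken together, patches 1-3 change when and how much the summarizer compacts: `rewrite` still fires only once `checkpoint_data.total_token_count` crosses `max_tokens`, but it then summarizes everything except the most recent `max_tokens - summarization_offset` tokens, and `generate` re-runs the moderator after appending each response so the exchange stays below the limit between calls. A minimal usage sketch under those semantics follows; the user prompt, the model name, and the choice of OpenAiProvider are illustrative, while the constructor arguments match the defaults from patches 1 and 3:

    from exchange import Message
    from exchange.exchange import Exchange
    from exchange.moderators.summarizer import ContextSummarizer
    from exchange.providers.openai import OpenAiProvider  # requires OPENAI_API_KEY

    # Summarization triggers once total_token_count exceeds max_tokens, then
    # compacts history down to max_tokens - summarization_offset
    # (128000 - 30000 = 98000 with these defaults), leaving enough headroom
    # that the next several generate() calls skip re-summarizing.
    ex = Exchange(
        provider=OpenAiProvider.from_env(),
        model="gpt-4o-mini",
        system="You are a helpful assistant.",
        moderator=ContextSummarizer(max_tokens=128000, summarization_offset=30000),
    )
    ex.add(Message.user("Hello"))
    reply = ex.generate()  # per patch 2, rewrite() also runs after the reply is added

Note that because patch 7 reverts the default moderator to ContextTruncate, the summarizer must be passed explicitly as shown above.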