chore: [testing] Mic/todos in moderators fix #56

Closed
wants to merge 8 commits
12 changes: 4 additions & 8 deletions .github/workflows/ci.yaml
@@ -43,11 +43,7 @@ jobs:
strategy:
matrix:
python-version:
# Only test the lastest python version.
- "3.12"
ollama-model:
# For quicker CI, use a smaller, tool-capable model than the default.
- "qwen2.5:0.5b"
- "3.12" # Only test the latest python version.

steps:
- uses: actions/checkout@v4
@@ -75,11 +71,11 @@ jobs:
# Tests use the OpenAI-compatible API, which has no mechanism to pull models. Run a
# simple prompt to (pull and) test the model first.
- name: Test Ollama model
run: ollama run $OLLAMA_MODEL hello || cat ollama.log
run: |
OLLAMA_MODEL=$(uv run python -c "from src.exchange.providers.ollama import OLLAMA_MODEL; print(OLLAMA_MODEL)")
ollama run $OLLAMA_MODEL hello || cat ollama.log
env:
OLLAMA_MODEL: ${{ matrix.ollama-model }}

- name: Run Ollama tests
run: uv run pytest tests -m integration -k ollama
env:
OLLAMA_MODEL: ${{ matrix.ollama-model }}
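
Aside: the rewritten step derives the model name from the provider module itself, so the workflow and the code cannot drift apart. A minimal sketch of the constant the `python -c` one-liner reads, assuming a plain module-level string (the provider file is not part of this diff; the default shown is taken from the old matrix entry):

    # src/exchange/providers/ollama.py -- sketch, file not shown in this diff.
    # A small, tool-capable default keeps CI fast; tests can override it
    # through the OLLAMA_MODEL environment variable (see conftest.py below).
    OLLAMA_MODEL = "qwen2.5:0.5b"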
7 changes: 2 additions & 5 deletions src/exchange/exchange.py
@@ -81,11 +81,8 @@ def generate(self) -> Message:
self.add(message)
self.add_checkpoints_from_usage(usage) # this has to come after adding the response

# TODO: also call `rewrite` here, as this will make our
# messages *consistently* below the token limit. this currently
# is not the case because we could append a large message after calling
# `rewrite` above.
# self.moderator.rewrite(self)
# also call `rewrite` here, as this keeps our messages consistently below the token limit.
self.moderator.rewrite(self)

_token_usage_collector.collect(self.model, usage)
return message
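
For context, ordering in `generate` matters here: the moderator now runs again after the response is appended, so the token count it checks includes the newest message. A condensed sketch of the surrounding method, where only the ordering comes from this diff and the provider call shape is an assumption:

    # Condensed sketch of Exchange.generate(); the complete() signature
    # is assumed from the providers' API elsewhere in this PR.
    def generate(self) -> Message:
        self.moderator.rewrite(self)  # trim/summarize before calling the provider
        message, usage = self.provider.complete(
            model=self.model, system=self.system,
            messages=self.messages, tools=self.tools,
        )
        self.add(message)
        self.add_checkpoints_from_usage(usage)  # must follow self.add(message)
        self.moderator.rewrite(self)  # keeps the limit holding after every call
        _token_usage_collector.collect(self.model, usage)
        return message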
23 changes: 21 additions & 2 deletions src/exchange/moderators/summarizer.py
@@ -3,19 +3,38 @@
from exchange import Message
from exchange.checkpoint import CheckpointData
from exchange.moderators import ContextTruncate, PassiveModerator
from exchange.moderators.truncate import MAX_TOKENS

# this offset is used to prevent summarization from happening too frequently
SUMMARIZATION_OFFSET = 30000


class ContextSummarizer(ContextTruncate):
def __init__(
self,
model: str = "gpt-4o-mini",
max_tokens: int = MAX_TOKENS,
summarization_offset: int = SUMMARIZATION_OFFSET,
) -> None:
super().__init__(model=model, max_tokens=max_tokens)
self.summarization_offset = summarization_offset

def rewrite(self, exchange: Type["exchange.exchange.Exchange"]) -> None: # noqa: F821
"""Summarize the context history up to the last few messages in the exchange"""

self._update_system_prompt_token_count(exchange)

# TODO: use an offset for summarization
# note: we don't use the num_tokens_to_keep (defined below) here, because we only want to trigger
# summarization when we pass the max_token threshold. at that point, we will summarize more than
# we need to (using the offset), so we don't need to summarize on every ex.generate(...) call
if exchange.checkpoint_data.total_token_count < self.max_tokens:
return

messages_to_summarize = self._get_messages_to_remove(exchange)
assert self.summarization_offset < self.max_tokens

num_tokens_to_keep = self.max_tokens - self.summarization_offset

messages_to_summarize = self._get_messages_to_remove(exchange, max_tokens=num_tokens_to_keep)
num_messages_to_remove = len(messages_to_summarize)

# the llm will throw an error if the last message isn't a user message
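
The offset gives summarization hysteresis: nothing happens until the hard limit is crossed, and the context is then compacted to a watermark well below it, so the next several thousand tokens of conversation trigger no further work. A worked sketch with the defaults from this diff:

    # Hysteresis created by the offset (default values from this diff).
    MAX_TOKENS = 128_000
    SUMMARIZATION_OFFSET = 30_000

    def should_summarize(total_token_count: int) -> bool:
        # Summarization only fires once the hard limit is crossed...
        return total_token_count >= MAX_TOKENS

    # ...and then the context is compacted down to this watermark, leaving
    # ~30k tokens of headroom before the next summarization pass.
    NUM_TOKENS_TO_KEEP = MAX_TOKENS - SUMMARIZATION_OFFSET  # 98_000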
11 changes: 7 additions & 4 deletions src/exchange/moderators/truncate.py
@@ -14,7 +14,7 @@
# so once we get to this token size the token count will exceed this
# by a little bit.
# TODO: make this configurable for each provider
MAX_TOKENS = 100000
MAX_TOKENS = 128000


class ContextTruncate(Moderator):
@@ -62,15 +62,18 @@ def _update_system_prompt_token_count(self, exchange: Exchange) -> None:
exchange.checkpoint_data.total_token_count -= last_system_prompt_token_count
exchange.checkpoint_data.total_token_count += self.system_prompt_token_count

def _get_messages_to_remove(self, exchange: Exchange) -> List[Message]:
def _get_messages_to_remove(self, exchange: Exchange, max_tokens: Optional[int] = None) -> List[Message]:
if max_tokens is None:
max_tokens = self.max_tokens

# this keeps all the messages/checkpoints
throwaway_exchange = exchange.replace(
moderator=PassiveModerator(),
)

# get the messages that we want to remove
# get the messages that we want to summarize
messages_to_remove = []
while throwaway_exchange.checkpoint_data.total_token_count > self.max_tokens:
while throwaway_exchange.checkpoint_data.total_token_count > max_tokens:
_, messages = throwaway_exchange.pop_first_checkpoint()
messages_to_remove.extend(messages)

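
The new 128000 default presumably tracks the 128k context window of the gpt-4o family. The TODO above still stands; a hedged sketch of the per-provider table it asks for (names and values are illustrative, not part of this PR):

    # Hypothetical per-provider limits; illustrative only.
    PROVIDER_MAX_TOKENS = {
        "openai": 128_000,  # gpt-4o / gpt-4o-mini context window
        "ollama": 32_000,   # varies by local model
    }

    def max_tokens_for(provider: str, default: int = 128_000) -> int:
        return PROVIDER_MAX_TOKENS.get(provider, default)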
17 changes: 16 additions & 1 deletion tests/providers/openai/conftest.py
@@ -1,11 +1,19 @@
import os
from typing import Type, Tuple

import pytest

OPENAI_MODEL = "gpt-4o-mini"
from exchange import Message
from exchange.providers import Usage, Provider
from exchange.providers.ollama import OLLAMA_MODEL

OPENAI_API_KEY = "test_openai_api_key"
OPENAI_ORG_ID = "test_openai_org_key"
OPENAI_PROJECT_ID = "test_openai_project_id"

OLLAMA_MODEL = os.getenv("OLLAMA_MODEL", OLLAMA_MODEL)
OPENAI_MODEL = os.getenv("OPENAI_MODEL", "gpt-4o-mini")


@pytest.fixture
def default_openai_api_key(monkeypatch):
@@ -50,3 +58,10 @@ def scrub_response_headers(response):
response["headers"]["openai-organization"] = OPENAI_ORG_ID
response["headers"]["Set-Cookie"] = "test_set_cookie"
return response


def complete(provider_cls: Type[Provider], model: str) -> Tuple[Message, Usage]:
provider = provider_cls.from_env()
system = "You are a helpful assistant."
messages = [Message.user("Hello")]
return provider.complete(model=model, system=system, messages=messages, tools=None)
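
One subtlety in the conftest above: `OLLAMA_MODEL = os.getenv("OLLAMA_MODEL", OLLAMA_MODEL)` rebinds the imported name, which works because the right-hand side is evaluated before the assignment. An equivalent, more explicit spelling if the rebinding reads as surprising:

    # Equivalent to the rebinding above, with the default renamed for clarity.
    import os

    from exchange.providers.ollama import OLLAMA_MODEL as DEFAULT_OLLAMA_MODEL

    OLLAMA_MODEL = os.getenv("OLLAMA_MODEL", DEFAULT_OLLAMA_MODEL)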
24 changes: 6 additions & 18 deletions tests/providers/openai/test_ollama.py
@@ -1,33 +1,21 @@
from typing import Tuple

import os
import pytest

from exchange import Text
from exchange.message import Message
from exchange.providers.base import Usage
from exchange.providers.ollama import OllamaProvider, OLLAMA_MODEL
from exchange.providers.ollama import OllamaProvider
from .conftest import complete, OLLAMA_MODEL


@pytest.mark.vcr()
def test_ollama_completion(default_openai_api_key):
reply_message, reply_usage = ollama_complete()
def test_ollama_complete(default_openai_api_key):
reply_message, reply_usage = complete(OllamaProvider, OLLAMA_MODEL)

assert reply_message.content == [Text(text="Hello! I'm here to help. How can I assist you today? Let's chat. 😊")]
assert reply_usage.total_tokens == 33


@pytest.mark.integration
def test_ollama_completion_integration():
reply = ollama_complete()
def test_ollama_complete_integration():
reply = complete(OllamaProvider, OLLAMA_MODEL)

assert reply[0].content is not None
print("Completion content from OpenAI:", reply[0].content)


def ollama_complete() -> Tuple[Message, Usage]:
provider = OllamaProvider.from_env()
model = os.getenv("OLLAMA_MODEL", OLLAMA_MODEL)
system = "You are a helpful assistant."
messages = [Message.user("Hello")]
return provider.complete(model=model, system=system, messages=messages, tools=None)
30 changes: 6 additions & 24 deletions tests/providers/openai/test_openai.py
@@ -1,39 +1,21 @@
from typing import Tuple

import os
import pytest

from exchange import Text
from exchange.message import Message
from exchange.providers.base import Usage
from exchange.providers.openai import OpenAiProvider
from .conftest import OPENAI_MODEL, OPENAI_API_KEY
from .conftest import complete, OPENAI_MODEL


@pytest.mark.vcr()
def test_openai_completion(monkeypatch):
# When running VCR tests the first time, it needs OPENAI_API_KEY to call
# the real service. Afterward, it is not needed as VCR mocks the service.
if "OPENAI_API_KEY" not in os.environ:
monkeypatch.setenv("OPENAI_API_KEY", OPENAI_API_KEY)

reply_message, reply_usage = openai_complete()
def test_openai_complete(default_openai_api_key):
reply_message, reply_usage = complete(OpenAiProvider, OPENAI_MODEL)

assert reply_message.content == [Text(text="Hello! How can I assist you today?")]
assert reply_usage.total_tokens == 27


@pytest.mark.integration
def test_openai_completion_integration():
reply = openai_complete()
def test_openai_complete_integration():
reply = complete(OpenAiProvider, OPENAI_MODEL)

assert reply[0].content is not None
print("Completion content from OpenAI:", reply[0].content)


def openai_complete() -> Tuple[Message, Usage]:
provider = OpenAiProvider.from_env()
model = OPENAI_MODEL
system = "You are a helpful assistant."
messages = [Message.user("Hello")]
return provider.complete(model=model, system=system, messages=messages, tools=None)
print("Complete content from OpenAI:", reply[0].content)
22 changes: 10 additions & 12 deletions tests/test_summarizer.py
@@ -49,7 +49,7 @@ def exchange_instance():

@pytest.fixture
def summarizer_instance():
return ContextSummarizer(max_tokens=300)
return ContextSummarizer(max_tokens=300, summarization_offset=100)


def test_context_summarizer_rewrite(exchange_instance: Exchange, summarizer_instance: ContextSummarizer):
@@ -149,7 +149,6 @@ def complete(self, model, system, messages, tools):
system_prompt_tokens = 100
input_token_count = system_prompt_tokens

message = self.sequence[self.current_index]
if self.summarize_next:
text = "Summary message here"
self.summarize_next = False
@@ -160,6 +159,7 @@ def complete(self, model, system, messages, tools):
output_tokens=len(text) * 2,
total_tokens=40 + len(text) * 2,
)
message = self.sequence[self.current_index]

if len(messages) > 0 and type(messages[0].content[0]) is ToolResult:
raise ValueError("ToolResult should not be the first message")
@@ -199,9 +199,7 @@ def conversation_exchange_instance():
provider=AnotherMockProvider(),
model="test-model",
system="test-system",
moderator=ContextSummarizer(max_tokens=300),
# TODO: make it work with an offset so we don't have to send off requests basically
# at every generate step
moderator=ContextSummarizer(max_tokens=300, summarization_offset=100),
)
return ex

@@ -215,13 +213,13 @@ def test_summarizer_generic_conversation(conversation_exchange_instance: Exchange):
if message.text != "Summary message here":
i += 2
checkpoints = conversation_exchange_instance.checkpoint_data.checkpoints
assert conversation_exchange_instance.checkpoint_data.total_token_count == 570
assert len(checkpoints) == 10
assert len(conversation_exchange_instance.messages) == 10
assert checkpoints[0].start_index == 20
assert checkpoints[0].end_index == 20
assert conversation_exchange_instance.checkpoint_data.total_token_count == 148
assert len(checkpoints) == 4
assert len(conversation_exchange_instance.messages) == 4
assert checkpoints[0].start_index == 26
assert checkpoints[0].end_index == 26
assert checkpoints[-1].start_index == 29
assert checkpoints[-1].end_index == 29
assert conversation_exchange_instance.checkpoint_data.message_index_offset == 20
assert conversation_exchange_instance.provider.summarized_count == 12
assert conversation_exchange_instance.checkpoint_data.message_index_offset == 26
assert conversation_exchange_instance.provider.summarized_count == 10
assert conversation_exchange_instance.moderator.system_prompt_token_count == 100
2 changes: 1 addition & 1 deletion tests/test_truncate.py
@@ -128,5 +128,5 @@ def test_truncate_on_generic_conversation(conversation_exchange_instance: Exchange):
if message.text != "Summary message here":
i += 2
# ensure the total token count is not anything exorbitant
assert conversation_exchange_instance.checkpoint_data.total_token_count < 700
assert conversation_exchange_instance.checkpoint_data.total_token_count < 500
assert conversation_exchange_instance.moderator.system_prompt_token_count == 100