From d1739d0e123e3883d9441070cc7dce38206ec76e Mon Sep 17 00:00:00 2001 From: Matt Zhou Date: Thu, 3 Oct 2024 11:25:20 -0700 Subject: [PATCH] Fix small benchmark bugs --- letta/benchmark/benchmark.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/letta/benchmark/benchmark.py b/letta/benchmark/benchmark.py index 4031d4a7c2..7109210e9e 100644 --- a/letta/benchmark/benchmark.py +++ b/letta/benchmark/benchmark.py @@ -2,11 +2,11 @@ import time import uuid -from typing import Annotated +from typing import Annotated, Union import typer -from letta import create_client +from letta import LocalClient, RESTClient, create_client from letta.benchmark.constants import HUMAN, PERSONA, PROMPTS, TRIES from letta.config import LettaConfig @@ -17,11 +17,13 @@ app = typer.Typer() -def send_message(message: str, agent_id, turn: int, fn_type: str, print_msg: bool = False, n_tries: int = TRIES): +def send_message( + client: Union[LocalClient, RESTClient], message: str, agent_id, turn: int, fn_type: str, print_msg: bool = False, n_tries: int = TRIES +): try: print_msg = f"\t-> Now running {fn_type}. Progress: {turn}/{n_tries}" print(print_msg, end="\r", flush=True) - response = client.user_message(agent_id=agent_id, message=message, return_token_count=True) + response = client.user_message(agent_id=agent_id, message=message) if turn + 1 == n_tries: print(" " * len(print_msg), end="\r", flush=True) @@ -65,7 +67,7 @@ def bench( agent_id = agent.id result, msg = send_message( - message=message, agent_id=agent_id, turn=i, fn_type=fn_type, print_msg=print_messages, n_tries=n_tries + client=client, message=message, agent_id=agent_id, turn=i, fn_type=fn_type, print_msg=print_messages, n_tries=n_tries ) if print_messages: