From 02dc8cb8a737ba651d2a4e203dd312ebfa303275 Mon Sep 17 00:00:00 2001 From: cpacker Date: Mon, 15 Apr 2024 20:26:48 -0700 Subject: [PATCH 1/9] add streaming implementation (no interface handler yet) on OpenAI-compatible endpoints --- memgpt/llm_api/llm_api_tools.py | 29 ++- memgpt/llm_api/openai.py | 221 +++++++++++++++++++--- memgpt/models/chat_completion_response.py | 63 ++++++ poetry.lock | 15 +- pyproject.toml | 1 + 5 files changed, 299 insertions(+), 30 deletions(-) diff --git a/memgpt/llm_api/llm_api_tools.py b/memgpt/llm_api/llm_api_tools.py index 12444a0781..bc7510db02 100644 --- a/memgpt/llm_api/llm_api_tools.py +++ b/memgpt/llm_api/llm_api_tools.py @@ -13,7 +13,7 @@ from memgpt.data_types import AgentState, Message -from memgpt.llm_api.openai import openai_chat_completions_request +from memgpt.llm_api.openai import openai_chat_completions_request, openai_chat_completions_process_stream from memgpt.llm_api.azure_openai import azure_openai_chat_completions_request, MODEL_TO_AZURE_ENGINE from memgpt.llm_api.google_ai import ( google_ai_chat_completions_request, @@ -134,6 +134,10 @@ def create( # use tool naming? # if false, will use deprecated 'functions' style use_tool_naming=True, + # streaming? + # stream=False, + stream=True, + stream_inferface=None, ) -> ChatCompletionResponse: """Return response to chat completion with backoff""" from memgpt.utils import printd @@ -169,11 +173,24 @@ def create( function_call=function_call, user=str(agent_state.user_id), ) - return openai_chat_completions_request( - url=agent_state.llm_config.model_endpoint, # https://api.openai.com/v1 -> https://api.openai.com/v1/chat/completions - api_key=credentials.openai_key, - data=data, - ) + + if stream: + data.stream = True + input("stream") + return openai_chat_completions_process_stream( + url=agent_state.llm_config.model_endpoint, # https://api.openai.com/v1 -> https://api.openai.com/v1/chat/completions + api_key=credentials.openai_key, + chat_completion_request=data, + stream_inferface=stream_inferface, + ) + else: + input("no stream") + data.stream = False + return openai_chat_completions_request( + url=agent_state.llm_config.model_endpoint, # https://api.openai.com/v1 -> https://api.openai.com/v1/chat/completions + api_key=credentials.openai_key, + chat_completion_request=data, + ) # azure elif agent_state.llm_config.model_endpoint_type == "azure": diff --git a/memgpt/llm_api/openai.py b/memgpt/llm_api/openai.py index f24b61b8fe..d33ab881d3 100644 --- a/memgpt/llm_api/openai.py +++ b/memgpt/llm_api/openai.py @@ -1,11 +1,25 @@ import requests -import time -from typing import Union, Optional +import json +import httpx +from httpx_sse import connect_sse +from httpx_sse._exceptions import SSEError +from typing import Union, Optional, Generator -from memgpt.models.chat_completion_response import ChatCompletionResponse +from memgpt.models.chat_completion_response import ( + ChatCompletionResponse, + Choice, + Message, + ToolCall, + FunctionCall, + UsageStatistics, + ChatCompletionChunkResponse, +) from memgpt.models.chat_completion_request import ChatCompletionRequest from memgpt.models.embedding_response import EmbeddingResponse -from memgpt.utils import smart_urljoin +from memgpt.utils import smart_urljoin, get_utc_time + + +OPENAI_SSE_DONE = "[DONE]" def openai_get_model_list(url: str, api_key: Union[str, None], fix_url: Optional[bool] = False) -> dict: @@ -58,13 +72,136 @@ def openai_get_model_list(url: str, api_key: Union[str, None], fix_url: Optional raise e -def 
openai_chat_completions_request(url: str, api_key: str, data: ChatCompletionRequest) -> ChatCompletionResponse: - """https://platform.openai.com/docs/guides/text-generation?lang=curl""" +def openai_chat_completions_process_stream( + url: str, + api_key: str, + chat_completion_request: ChatCompletionRequest, + stream_inferface: Optional[str] = None, +) -> ChatCompletionResponse: + """Process a streaming completion response, and return a ChatCompletionRequest at the end. + + To "stream" the response in MemGPT, we want to call a streaming-compatible interface function + on the chunks received from the OpenAI-compatible server POST SSE response. + """ + assert chat_completion_request.stream == True + + chat_completion_response = ChatCompletionResponse( + id="", # NOTE: requires overwrite + choices=[], + created=get_utc_time(), + model=chat_completion_request.model, + usage=UsageStatistics( + completion_tokens=0, + prompt_tokens=0, + total_tokens=0, + ), + ) + + TEMP_STREAM_FINISH_REASON = "temp_null" + TEMP_STREAM_TOOL_CALL_ID = "temp_id" + for chunk_idx, chat_completion_chunk in enumerate( + openai_chat_completions_request(url=url, api_key=api_key, chat_completion_request=chat_completion_request) + ): + assert isinstance(chat_completion_chunk, ChatCompletionChunkResponse), type(chat_completion_chunk) + # stream_inferface.process(chat_completion_chunk) + print(chat_completion_chunk) + + if chunk_idx == 0: + # initialize the choice objects which we will increment with the deltas + num_choices = len(chat_completion_chunk.choices) + assert num_choices > 0 + chat_completion_response.choices = [ + Choice( + finish_reason=TEMP_STREAM_FINISH_REASON, # NOTE: needs to be ovrerwritten + index=i, + message=Message( + role="assistant", + ), + ) + for i in range(len(chat_completion_chunk.choices)) + ] + + # add the choice delta + assert len(chat_completion_chunk.choices) == len(chat_completion_response.choices), chat_completion_chunk + for chunk_choice in chat_completion_chunk.choices: + if chunk_choice.finish_reason is not None: + chat_completion_response.choices[chunk_choice.index].finish_reason = chunk_choice.finish_reason + + if chunk_choice.logprobs is not None: + chat_completion_response.choices[chunk_choice.index].logprobs = chunk_choice.logprobs + + accum_message = chat_completion_response.choices[chunk_choice.index].message + message_delta = chunk_choice.delta + + if message_delta.content is not None: + content_delta = message_delta.content + if accum_message.content is None: + accum_message.content = content_delta + else: + accum_message.content += content_delta + + if message_delta.tool_calls is not None: + tool_calls_delta = message_delta.tool_calls + + # If this is the first tool call showing up in a chunk, initialize the list with it + if accum_message.tool_calls is None: + accum_message.tool_calls = [ + ToolCall(id=TEMP_STREAM_TOOL_CALL_ID, function=FunctionCall(name="", arguments="")) + for _ in range(len(tool_calls_delta)) + ] + + for tool_call_delta in tool_calls_delta: + if tool_call_delta.id is not None: + # TODO assert that we're not overwriting? + # TODO += instead of =? + accum_message.tool_calls[tool_call_delta.index].id = tool_call_delta.id + if tool_call_delta.function is not None: + if tool_call_delta.function.name is not None: + # TODO assert that we're not overwriting? + # TODO += instead of =? 
+ accum_message.tool_calls[tool_call_delta.index].function.name = tool_call_delta.function.name + if tool_call_delta.function.arguments is not None: + accum_message.tool_calls[tool_call_delta.index].function.arguments += tool_call_delta.function.arguments + + if message_delta.function_call is not None: + raise NotImplementedError(f"Old function_call style not support with stream=True") + + # overwrite response fields based on latest chunk + chat_completion_response.id = chat_completion_chunk.id + chat_completion_response.system_fingerprint = chat_completion_chunk.system_fingerprint + chat_completion_response.created = chat_completion_chunk.created + chat_completion_response.model = chat_completion_chunk.model + + # increment chunk counter + chunk_idx += 1 + + # compute token usage before returning + # TODO + print("choices=", chat_completion_response.choices) + + return chat_completion_response + + +def openai_chat_completions_request( + url: str, + api_key: str, + chat_completion_request: ChatCompletionRequest, +) -> Union[ChatCompletionResponse, Generator[ChatCompletionChunkResponse, None, None]]: + """Send a ChatCompletion request to an OpenAI-compatible server + + If request.stream == True, will yield ChatCompletionChunkResponses + If request.stream == False, will return a ChatCompletionResponse + + https://platform.openai.com/docs/guides/text-generation?lang=curl + """ from memgpt.utils import printd url = smart_urljoin(url, "chat/completions") headers = {"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"} - data = data.model_dump(exclude_none=True) + data = chat_completion_request.model_dump(exclude_none=True) + + # import json + # print(json.dumps(data, indent=2)) # If functions == None, strip from the payload if "functions" in data and data["functions"] is None: @@ -77,23 +214,63 @@ def openai_chat_completions_request(url: str, api_key: str, data: ChatCompletion printd(f"Sending request to {url}") try: - # Example code to trigger a rate limit response: - # mock_response = requests.Response() - # mock_response.status_code = 429 - # http_error = requests.exceptions.HTTPError("429 Client Error: Too Many Requests") - # http_error.response = mock_response - # raise http_error + if data["stream"] == True: - # Example code to trigger a context overflow response (for an 8k model) - # data["messages"][-1]["content"] = " ".join(["repeat after me this is not a fluke"] * 1000) + with httpx.Client() as client: + with connect_sse(client, method="POST", url=url, json=data, headers=headers) as event_source: + try: + for sse in event_source.iter_sse(): + print(sse.event, sse.data, sse.id, sse.retry) + if sse.data == OPENAI_SSE_DONE: + # print("finished") + break + else: + chunk_data = json.loads(sse.data) + # print("chunk_data::", chunk_data) + chunk_object = ChatCompletionChunkResponse(**chunk_data) + # print("chunk_object::", chunk_object) + # id=chunk_data["id"], + # choices=[ChunkChoice], + # model=chunk_data["model"], + # system_fingerprint=chunk_data["system_fingerprint"] + # ) + yield chunk_object - response = requests.post(url, headers=headers, json=data) - printd(f"response = {response}") - response.raise_for_status() # Raises HTTPError for 4XX/5XX status - response = response.json() # convert to dict from string - printd(f"response.json = {response}") - response = ChatCompletionResponse(**response) # convert to 'dot-dict' style which is the openai python client default - return response + except SSEError as e: + if "application/json" in str(e): # Check if the error 
is because of JSON response + response = client.post(url=url, json=data, headers=headers) # Make the request again to get the JSON response + if response.headers["Content-Type"].startswith("application/json"): + error_details = response.json() # Parse the JSON to get the error message + print("Error:", error_details) + print("Reqeust:", vars(response.request)) + else: + print("Failed to retrieve JSON error message.") + else: + print("SSEError not related to 'application/json' content type.") + + # Optionally re-raise the exception if you need to propagate it + raise e + + except Exception as e: + if event_source.response.request is not None: + print("HTTP Request:", vars(event_source.response.request)) + if event_source.response is not None: + print("HTTP Status:", event_source.response.status_code) + print("HTTP Headers:", event_source.response.headers) + # print("HTTP Body:", event_source.response.text) + print("Exception message:", str(e)) + raise e + + else: + response = requests.post(url, headers=headers, json=data) + printd(f"response = {response}") + response.raise_for_status() # Raises HTTPError for 4XX/5XX status + + response = response.json() # convert to dict from string + printd(f"response.json = {response}") + + response = ChatCompletionResponse(**response) # convert to 'dot-dict' style which is the openai python client default + return response except requests.exceptions.HTTPError as http_err: # Handle HTTP errors (e.g., response 4XX, 5XX) printd(f"Got HTTPError, exception={http_err}, payload={data}") diff --git a/memgpt/models/chat_completion_response.py b/memgpt/models/chat_completion_response.py index ebd172d368..56dab43696 100644 --- a/memgpt/models/chat_completion_response.py +++ b/memgpt/models/chat_completion_response.py @@ -55,6 +55,8 @@ class UsageStatistics(BaseModel): class ChatCompletionResponse(BaseModel): + """https://platform.openai.com/docs/api-reference/chat/object""" + id: str choices: List[Choice] created: datetime.datetime @@ -64,3 +66,64 @@ class ChatCompletionResponse(BaseModel): # object: str = Field(default="chat.completion") object: Literal["chat.completion"] = "chat.completion" usage: UsageStatistics + + +class FunctionCallDelta(BaseModel): + # arguments: Optional[str] = None + name: Optional[str] = None + arguments: str + # name: str + + +class ToolCallDelta(BaseModel): + index: int + id: Optional[str] = None + # "Currently, only function is supported" + type: Literal["function"] = "function" + # function: ToolCallFunction + function: Optional[FunctionCallDelta] = None + + +class MessageDelta(BaseModel): + """Partial delta stream of a Message + + Example ChunkResponse: + { + 'id': 'chatcmpl-9EOCkKdicNo1tiL1956kPvCnL2lLS', + 'object': 'chat.completion.chunk', + 'created': 1713216662, + 'model': 'gpt-4-0613', + 'system_fingerprint': None, + 'choices': [{ + 'index': 0, + 'delta': {'content': 'User'}, + 'logprobs': None, + 'finish_reason': None + }] + } + """ + + content: Optional[str] = None + tool_calls: Optional[List[ToolCallDelta]] = None + # role: Optional[str] = None + function_call: Optional[FunctionCallDelta] = None # Deprecated + + +class ChunkChoice(BaseModel): + finish_reason: Optional[str] = None # NOTE: when streaming will be null + index: int + delta: MessageDelta + logprobs: Optional[Dict[str, Union[List[MessageContentLogProb], None]]] = None + + +class ChatCompletionChunkResponse(BaseModel): + """https://platform.openai.com/docs/api-reference/chat/streaming""" + + id: str + choices: List[ChunkChoice] + created: datetime.datetime + model: 
str + # system_fingerprint: str # docs say this is mandatory, but in reality API returns None + system_fingerprint: Optional[str] = None + # object: str = Field(default="chat.completion") + object: Literal["chat.completion.chunk"] = "chat.completion.chunk" diff --git a/poetry.lock b/poetry.lock index dad0f73189..2edac8eb45 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.7.0 and should not be changed by hand. [[package]] name = "aiohttp" @@ -1524,6 +1524,17 @@ cli = ["click (==8.*)", "pygments (==2.*)", "rich (>=10,<14)"] http2 = ["h2 (>=3,<5)"] socks = ["socksio (==1.*)"] +[[package]] +name = "httpx-sse" +version = "0.4.0" +description = "Consume Server-Sent Event (SSE) messages with HTTPX." +optional = false +python-versions = ">=3.8" +files = [ + {file = "httpx-sse-0.4.0.tar.gz", hash = "sha256:1e81a3a3070ce322add1d3529ed42eb5f70817f45ed6ec915ab753f961139721"}, + {file = "httpx_sse-0.4.0-py3-none-any.whl", hash = "sha256:f329af6eae57eaa2bdfd962b42524764af68075ea87370a2de920af5341e318f"}, +] + [[package]] name = "huggingface-hub" version = "0.22.2" @@ -6094,4 +6105,4 @@ server = ["fastapi", "uvicorn", "websockets"] [metadata] lock-version = "2.0" python-versions = "<3.13,>=3.10" -content-hash = "5c36931d717323eab3eea32bf383b27578ea8f3467fd230ce543af364caffa92" +content-hash = "a9635dccf8bd7d826f776e36a9d6fbc845a1b7de0586d06c6a9ce7230a5a14bc" diff --git a/pyproject.toml b/pyproject.toml index 7c9e748b82..788dae14d8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -57,6 +57,7 @@ llama-index-embeddings-openai = "^0.1.1" llama-index-embeddings-huggingface = {version = "^0.2.0", optional = true} llama-index-embeddings-azure-openai = "^0.1.6" python-multipart = "^0.0.9" +httpx-sse = "^0.4.0" [tool.poetry.extras] local = ["llama-index-embeddings-huggingface"] From a8684b6fef194b29d8f1396748f1456ada44d591 Mon Sep 17 00:00:00 2001 From: cpacker Date: Mon, 15 Apr 2024 21:19:04 -0700 Subject: [PATCH 2/9] add asserts for temp flags, drop stray input --- memgpt/llm_api/llm_api_tools.py | 2 -- memgpt/llm_api/openai.py | 9 +++++++++ 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/memgpt/llm_api/llm_api_tools.py b/memgpt/llm_api/llm_api_tools.py index bc7510db02..f16f032c6e 100644 --- a/memgpt/llm_api/llm_api_tools.py +++ b/memgpt/llm_api/llm_api_tools.py @@ -176,7 +176,6 @@ def create( if stream: data.stream = True - input("stream") return openai_chat_completions_process_stream( url=agent_state.llm_config.model_endpoint, # https://api.openai.com/v1 -> https://api.openai.com/v1/chat/completions api_key=credentials.openai_key, @@ -184,7 +183,6 @@ def create( stream_inferface=stream_inferface, ) else: - input("no stream") data.stream = False return openai_chat_completions_request( url=agent_state.llm_config.model_endpoint, # https://api.openai.com/v1 -> https://api.openai.com/v1/chat/completions diff --git a/memgpt/llm_api/openai.py b/memgpt/llm_api/openai.py index d33ab881d3..bc6518edaf 100644 --- a/memgpt/llm_api/openai.py +++ b/memgpt/llm_api/openai.py @@ -175,6 +175,15 @@ def openai_chat_completions_process_stream( # increment chunk counter chunk_idx += 1 + # make sure we didn't leave temp stuff in + assert all([c.finish_reason != TEMP_STREAM_FINISH_REASON for c in chat_completion_response.choices]) + assert all( + [ + all([tc != TEMP_STREAM_TOOL_CALL_ID for tc in c.message.tool_calls]) if c.message.tool_calls else True + for c in 
chat_completion_response.choices + ] + ) + # compute token usage before returning # TODO print("choices=", chat_completion_response.choices) From 693372a239d9193409a064b0f04a998fe2d01108 Mon Sep 17 00:00:00 2001 From: cpacker Date: Tue, 16 Apr 2024 09:36:32 -0700 Subject: [PATCH 3/9] added a line delta streaming interface --- memgpt/llm_api/llm_api_tools.py | 3 + memgpt/llm_api/openai.py | 19 ++- memgpt/main.py | 21 ++-- memgpt/streaming_interface.py | 200 ++++++++++++++++++++++++++++++++ 4 files changed, 225 insertions(+), 18 deletions(-) create mode 100644 memgpt/streaming_interface.py diff --git a/memgpt/llm_api/llm_api_tools.py b/memgpt/llm_api/llm_api_tools.py index f16f032c6e..c3ab488f2b 100644 --- a/memgpt/llm_api/llm_api_tools.py +++ b/memgpt/llm_api/llm_api_tools.py @@ -176,6 +176,9 @@ def create( if stream: data.stream = True + from memgpt.streaming_interface import StreamingCLIInterface + + stream_inferface = StreamingCLIInterface() return openai_chat_completions_process_stream( url=agent_state.llm_config.model_endpoint, # https://api.openai.com/v1 -> https://api.openai.com/v1/chat/completions api_key=credentials.openai_key, diff --git a/memgpt/llm_api/openai.py b/memgpt/llm_api/openai.py index bc6518edaf..61d2661158 100644 --- a/memgpt/llm_api/openai.py +++ b/memgpt/llm_api/openai.py @@ -17,6 +17,8 @@ from memgpt.models.chat_completion_request import ChatCompletionRequest from memgpt.models.embedding_response import EmbeddingResponse from memgpt.utils import smart_urljoin, get_utc_time +from memgpt.interface import AgentInterface +from memgpt.streaming_interface import AgentStreamingInterface OPENAI_SSE_DONE = "[DONE]" @@ -76,7 +78,7 @@ def openai_chat_completions_process_stream( url: str, api_key: str, chat_completion_request: ChatCompletionRequest, - stream_inferface: Optional[str] = None, + stream_inferface: Optional[AgentStreamingInterface] = None, ) -> ChatCompletionResponse: """Process a streaming completion response, and return a ChatCompletionRequest at the end. 
@@ -97,14 +99,19 @@ def openai_chat_completions_process_stream( ), ) + if stream_inferface: + stream_inferface.stream_start() + TEMP_STREAM_FINISH_REASON = "temp_null" TEMP_STREAM_TOOL_CALL_ID = "temp_id" for chunk_idx, chat_completion_chunk in enumerate( openai_chat_completions_request(url=url, api_key=api_key, chat_completion_request=chat_completion_request) ): assert isinstance(chat_completion_chunk, ChatCompletionChunkResponse), type(chat_completion_chunk) - # stream_inferface.process(chat_completion_chunk) - print(chat_completion_chunk) + # print(chat_completion_chunk) + + if stream_inferface: + stream_inferface.process_chunk(chat_completion_chunk) if chunk_idx == 0: # initialize the choice objects which we will increment with the deltas @@ -175,6 +182,10 @@ def openai_chat_completions_process_stream( # increment chunk counter chunk_idx += 1 + # TODO change to a finally block + if stream_inferface: + stream_inferface.stream_end() + # make sure we didn't leave temp stuff in assert all([c.finish_reason != TEMP_STREAM_FINISH_REASON for c in chat_completion_response.choices]) assert all( @@ -229,7 +240,7 @@ def openai_chat_completions_request( with connect_sse(client, method="POST", url=url, json=data, headers=headers) as event_source: try: for sse in event_source.iter_sse(): - print(sse.event, sse.data, sse.id, sse.retry) + # printd(sse.event, sse.data, sse.id, sse.retry) if sse.data == OPENAI_SSE_DONE: # print("finished") break diff --git a/memgpt/main.py b/memgpt/main.py index feb83b6e91..f7b4375807 100644 --- a/memgpt/main.py +++ b/memgpt/main.py @@ -12,7 +12,6 @@ console = Console() -from memgpt.agent import save_agent from memgpt.agent_store.storage import StorageConnector, TableType from memgpt.interface import CLIInterface as interface # for printing to terminal from memgpt.config import MemGPTConfig @@ -194,9 +193,7 @@ def run_agent_loop(memgpt_agent, config: MemGPTConfig, first, ms: MetadataStore, else: print(f"Popping last {pop_amount} messages from stack") for _ in range(min(pop_amount, len(memgpt_agent.messages))): - memgpt_agent._messages.pop() - # Persist the state - save_agent(agent=memgpt_agent, ms=ms) + memgpt_agent.messages.pop() continue elif user_input.lower() == "/retry": @@ -218,13 +215,7 @@ def run_agent_loop(memgpt_agent, config: MemGPTConfig, first, ms: MetadataStore, for x in range(len(memgpt_agent.messages) - 1, 0, -1): if memgpt_agent.messages[x].get("role") == "assistant": text = user_input[len("/rethink ") :].strip() - - # Do the /rethink-ing - message_obj = memgpt_agent._messages[x] - message_obj.text = text - - # To persist to the database, all we need to do is "re-insert" into recall memory - memgpt_agent.persistence_manager.recall_memory.storage.update(record=message_obj) + memgpt_agent.messages[x].update({"content": text}) break continue @@ -376,9 +367,11 @@ def process_agent_step(user_message, no_verify): new_messages, user_message, skip_next_user_input = process_agent_step(user_message, no_verify) break else: - with console.status("[bold cyan]Thinking...") as status: - new_messages, user_message, skip_next_user_input = process_agent_step(user_message, no_verify) - break + # with console.status("[bold cyan]Thinking...") as status: + # new_messages, user_message, skip_next_user_input = process_agent_step(user_message, no_verify) + # break + new_messages, user_message, skip_next_user_input = process_agent_step(user_message, no_verify) + break except KeyboardInterrupt: print("User interrupt occurred.") retry = questionary.confirm("Retry 
agent.step()?").ask() diff --git a/memgpt/streaming_interface.py b/memgpt/streaming_interface.py new file mode 100644 index 0000000000..d344e9d829 --- /dev/null +++ b/memgpt/streaming_interface.py @@ -0,0 +1,200 @@ +from abc import ABC, abstractmethod +import json +import re +from typing import List, Optional + +from colorama import Fore, Style, init + +from memgpt.utils import printd +from memgpt.constants import CLI_WARNING_PREFIX, JSON_LOADS_STRICT +from memgpt.data_types import Message +from memgpt.models.chat_completion_response import ChatCompletionChunkResponse +from memgpt.interface import AgentInterface, CLIInterface + +init(autoreset=True) + +# DEBUG = True # puts full message outputs in the terminal +DEBUG = False # only dumps important messages in the terminal + +STRIP_UI = False + + +class AgentStreamingInterface(ABC): + """Interfaces handle MemGPT-related events (observer pattern) + + The 'msg' args provides the scoped message, and the optional Message arg can provide additional metadata. + """ + + @abstractmethod + def user_message(self, msg: str, msg_obj: Optional[Message] = None): + """MemGPT receives a user message""" + raise NotImplementedError + + @abstractmethod + def internal_monologue(self, msg: str, msg_obj: Optional[Message] = None): + """MemGPT generates some internal monologue""" + raise NotImplementedError + + @abstractmethod + def assistant_message(self, msg: str, msg_obj: Optional[Message] = None): + """MemGPT uses send_message""" + raise NotImplementedError + + @abstractmethod + def function_message(self, msg: str, msg_obj: Optional[Message] = None): + """MemGPT calls a function""" + raise NotImplementedError + + @abstractmethod + def process_chunk(self, chunk: ChatCompletionChunkResponse): + """Process a streaming chunk from an OpenAI-compatible server""" + raise NotImplementedError + + @abstractmethod + def stream_start(self): + """Any setup required before streaming begins""" + raise NotImplementedError + + @abstractmethod + def stream_end(self): + """Any cleanup required after streaming ends""" + raise NotImplementedError + + +class StreamingCLIInterface(AgentStreamingInterface): + """Version of the CLI interface that attaches to a stream generator and prints along the way. + + When a chunk is received, we write the delta to the buffer. If the buffer type has changed, + we write out a newline + set the formatting for the new line. + + The two buffer types are: + (1) content (inner thoughts) + (2) tool_calls (function calling) + + NOTE: this assumes that the deltas received in the chunks are in-order, e.g. + that once 'content' deltas stop streaming, they won't be received again. See notes + on alternative version of the StreamingCLIInterface that does not have this same problem below: + + An alternative implementation could instead maintain the partial message state, and on each + process chunk (1) update the partial message state, (2) refresh/rewrite the state to the screen. 
+ """ + + # CLIInterface is static/stateless + nonstreaming_interface = CLIInterface() + + def __init__(self): + """The streaming CLI interface state for determining which buffer is currently being written to""" + + self.streaming_buffer_type = None + + def _flush(self): + pass + + def process_chunk(self, chunk: ChatCompletionChunkResponse): + assert len(chunk.choices) == 1, chunk + + message_delta = chunk.choices[0].delta + + # Starting a new buffer line + if not self.streaming_buffer_type: + assert not ( + message_delta.content is not None and message_delta.tool_calls is not None and len(message_delta.tool_calls) + ), f"Error: got both content and tool_calls in message stream\n{message_delta}" + + if message_delta.content is not None: + # Write out the prefix for inner thoughts + print("Inner thoughts: ", end="", flush=True) + elif message_delta.tool_calls is not None: + assert len(message_delta.tool_calls) == 1, f"Error: got more than one tool call in response\n{message_delta}" + # Write out the prefix for function calling + print("Calling function: ", end="", flush=True) + + # Potentially switch/flush a buffer line + else: + pass + + # Write out the delta + if message_delta.content is not None: + if self.streaming_buffer_type and self.streaming_buffer_type != "content": + print() + self.streaming_buffer_type = "content" + + # Simple, just write out to the buffer + print(message_delta.content, end="", flush=True) + + elif message_delta.tool_calls is not None: + if self.streaming_buffer_type and self.streaming_buffer_type != "tool_calls": + print() + self.streaming_buffer_type = "tool_calls" + + assert len(message_delta.tool_calls) == 1, f"Error: got more than one tool call in response\n{message_delta}" + function_call = message_delta.tool_calls[0].function + + # Slightly more complex - want to write parameters in a certain way (paren-style) + # function_name(function_args) + if function_call.name: + # NOTE: need to account for closing the brace later + print(f"{function_call.name}(", end="", flush=True) + if function_call.arguments: + print(function_call.arguments, end="", flush=True) + + def stream_start(self): + # should be handled by stream_end(), but just in case + self.streaming_buffer_type = None + + def stream_end(self): + if self.streaming_buffer_type is not None: + # TODO: should have a separate self.tool_call_open_paren flag + if self.streaming_buffer_type == "tool_calls": + print(")", end="", flush=True) + + print() # newline to move the cursor + self.streaming_buffer_type = None # reset buffer tracker + + @staticmethod + def important_message(msg: str): + StreamingCLIInterface.nonstreaming_interface(msg) + + @staticmethod + def warning_message(msg: str): + StreamingCLIInterface.nonstreaming_interface(msg) + + @staticmethod + def internal_monologue(msg: str, msg_obj: Optional[Message] = None): + StreamingCLIInterface.nonstreaming_interface(msg, msg_obj) + + @staticmethod + def assistant_message(msg: str, msg_obj: Optional[Message] = None): + StreamingCLIInterface.nonstreaming_interface(msg, msg_obj) + + @staticmethod + def memory_message(msg: str, msg_obj: Optional[Message] = None): + StreamingCLIInterface.nonstreaming_interface(msg, msg_obj) + + @staticmethod + def system_message(msg: str, msg_obj: Optional[Message] = None): + StreamingCLIInterface.nonstreaming_interface(msg, msg_obj) + + @staticmethod + def user_message(msg: str, msg_obj: Optional[Message] = None, raw: bool = False, dump: bool = False, debug: bool = DEBUG): + 
StreamingCLIInterface.nonstreaming_interface(msg, msg_obj) + + @staticmethod + def function_message(msg: str, msg_obj: Optional[Message] = None, debug: bool = DEBUG): + StreamingCLIInterface.nonstreaming_interface(msg, msg_obj) + + @staticmethod + def print_messages(message_sequence: List[Message], dump=False): + StreamingCLIInterface.nonstreaming_interface(message_sequence, dump) + + @staticmethod + def print_messages_simple(message_sequence: List[Message]): + StreamingCLIInterface.nonstreaming_interface.print_messages_simple(message_sequence) + + @staticmethod + def print_messages_raw(message_sequence: List[Message]): + StreamingCLIInterface.nonstreaming_interface.print_messages_raw(message_sequence) + + @staticmethod + def step_yield(): + pass From 037396f1fef3b1fb90933afc93d6e79022d5f8c1 Mon Sep 17 00:00:00 2001 From: cpacker Date: Tue, 16 Apr 2024 10:05:56 -0700 Subject: [PATCH 4/9] put stream parsing in try/except/finally --- memgpt/llm_api/openai.py | 164 ++++++++++++++++++++------------------- 1 file changed, 85 insertions(+), 79 deletions(-) diff --git a/memgpt/llm_api/openai.py b/memgpt/llm_api/openai.py index 61d2661158..be68045c9b 100644 --- a/memgpt/llm_api/openai.py +++ b/memgpt/llm_api/openai.py @@ -104,87 +104,93 @@ def openai_chat_completions_process_stream( TEMP_STREAM_FINISH_REASON = "temp_null" TEMP_STREAM_TOOL_CALL_ID = "temp_id" - for chunk_idx, chat_completion_chunk in enumerate( - openai_chat_completions_request(url=url, api_key=api_key, chat_completion_request=chat_completion_request) - ): - assert isinstance(chat_completion_chunk, ChatCompletionChunkResponse), type(chat_completion_chunk) - # print(chat_completion_chunk) - if stream_inferface: - stream_inferface.process_chunk(chat_completion_chunk) - - if chunk_idx == 0: - # initialize the choice objects which we will increment with the deltas - num_choices = len(chat_completion_chunk.choices) - assert num_choices > 0 - chat_completion_response.choices = [ - Choice( - finish_reason=TEMP_STREAM_FINISH_REASON, # NOTE: needs to be ovrerwritten - index=i, - message=Message( - role="assistant", - ), - ) - for i in range(len(chat_completion_chunk.choices)) - ] - - # add the choice delta - assert len(chat_completion_chunk.choices) == len(chat_completion_response.choices), chat_completion_chunk - for chunk_choice in chat_completion_chunk.choices: - if chunk_choice.finish_reason is not None: - chat_completion_response.choices[chunk_choice.index].finish_reason = chunk_choice.finish_reason - - if chunk_choice.logprobs is not None: - chat_completion_response.choices[chunk_choice.index].logprobs = chunk_choice.logprobs - - accum_message = chat_completion_response.choices[chunk_choice.index].message - message_delta = chunk_choice.delta - - if message_delta.content is not None: - content_delta = message_delta.content - if accum_message.content is None: - accum_message.content = content_delta - else: - accum_message.content += content_delta - - if message_delta.tool_calls is not None: - tool_calls_delta = message_delta.tool_calls - - # If this is the first tool call showing up in a chunk, initialize the list with it - if accum_message.tool_calls is None: - accum_message.tool_calls = [ - ToolCall(id=TEMP_STREAM_TOOL_CALL_ID, function=FunctionCall(name="", arguments="")) - for _ in range(len(tool_calls_delta)) - ] - - for tool_call_delta in tool_calls_delta: - if tool_call_delta.id is not None: - # TODO assert that we're not overwriting? - # TODO += instead of =? 
- accum_message.tool_calls[tool_call_delta.index].id = tool_call_delta.id - if tool_call_delta.function is not None: - if tool_call_delta.function.name is not None: + try: + for chunk_idx, chat_completion_chunk in enumerate( + openai_chat_completions_request(url=url, api_key=api_key, chat_completion_request=chat_completion_request) + ): + assert isinstance(chat_completion_chunk, ChatCompletionChunkResponse), type(chat_completion_chunk) + # print(chat_completion_chunk) + + if stream_inferface: + stream_inferface.process_chunk(chat_completion_chunk) + + if chunk_idx == 0: + # initialize the choice objects which we will increment with the deltas + num_choices = len(chat_completion_chunk.choices) + assert num_choices > 0 + chat_completion_response.choices = [ + Choice( + finish_reason=TEMP_STREAM_FINISH_REASON, # NOTE: needs to be ovrerwritten + index=i, + message=Message( + role="assistant", + ), + ) + for i in range(len(chat_completion_chunk.choices)) + ] + + # add the choice delta + assert len(chat_completion_chunk.choices) == len(chat_completion_response.choices), chat_completion_chunk + for chunk_choice in chat_completion_chunk.choices: + if chunk_choice.finish_reason is not None: + chat_completion_response.choices[chunk_choice.index].finish_reason = chunk_choice.finish_reason + + if chunk_choice.logprobs is not None: + chat_completion_response.choices[chunk_choice.index].logprobs = chunk_choice.logprobs + + accum_message = chat_completion_response.choices[chunk_choice.index].message + message_delta = chunk_choice.delta + + if message_delta.content is not None: + content_delta = message_delta.content + if accum_message.content is None: + accum_message.content = content_delta + else: + accum_message.content += content_delta + + if message_delta.tool_calls is not None: + tool_calls_delta = message_delta.tool_calls + + # If this is the first tool call showing up in a chunk, initialize the list with it + if accum_message.tool_calls is None: + accum_message.tool_calls = [ + ToolCall(id=TEMP_STREAM_TOOL_CALL_ID, function=FunctionCall(name="", arguments="")) + for _ in range(len(tool_calls_delta)) + ] + + for tool_call_delta in tool_calls_delta: + if tool_call_delta.id is not None: # TODO assert that we're not overwriting? # TODO += instead of =? - accum_message.tool_calls[tool_call_delta.index].function.name = tool_call_delta.function.name - if tool_call_delta.function.arguments is not None: - accum_message.tool_calls[tool_call_delta.index].function.arguments += tool_call_delta.function.arguments - - if message_delta.function_call is not None: - raise NotImplementedError(f"Old function_call style not support with stream=True") - - # overwrite response fields based on latest chunk - chat_completion_response.id = chat_completion_chunk.id - chat_completion_response.system_fingerprint = chat_completion_chunk.system_fingerprint - chat_completion_response.created = chat_completion_chunk.created - chat_completion_response.model = chat_completion_chunk.model - - # increment chunk counter - chunk_idx += 1 - - # TODO change to a finally block - if stream_inferface: - stream_inferface.stream_end() + accum_message.tool_calls[tool_call_delta.index].id = tool_call_delta.id + if tool_call_delta.function is not None: + if tool_call_delta.function.name is not None: + # TODO assert that we're not overwriting? + # TODO += instead of =? 
+ accum_message.tool_calls[tool_call_delta.index].function.name = tool_call_delta.function.name + if tool_call_delta.function.arguments is not None: + accum_message.tool_calls[tool_call_delta.index].function.arguments += tool_call_delta.function.arguments + + if message_delta.function_call is not None: + raise NotImplementedError(f"Old function_call style not support with stream=True") + + # overwrite response fields based on latest chunk + chat_completion_response.id = chat_completion_chunk.id + chat_completion_response.system_fingerprint = chat_completion_chunk.system_fingerprint + chat_completion_response.created = chat_completion_chunk.created + chat_completion_response.model = chat_completion_chunk.model + + # increment chunk counter + chunk_idx += 1 + except Exception as e: + if stream_inferface: + stream_inferface.stream_end() + print(f"Parsing ChatCompletion stream failed with error:\n{str(e)}") + raise e + finally: + if stream_inferface: + stream_inferface.stream_end() # make sure we didn't leave temp stuff in assert all([c.finish_reason != TEMP_STREAM_FINISH_REASON for c in chat_completion_response.choices]) @@ -197,7 +203,7 @@ def openai_chat_completions_process_stream( # compute token usage before returning # TODO - print("choices=", chat_completion_response.choices) + # print("choices=", chat_completion_response.choices) return chat_completion_response From 4e8f7a1252032854e0ad1efceb89d95a774b630c Mon Sep 17 00:00:00 2001 From: cpacker Date: Tue, 16 Apr 2024 16:20:21 -0700 Subject: [PATCH 5/9] added proper token counting to streaming requests --- memgpt/llm_api/openai.py | 47 ++++++++++--- memgpt/local_llm/utils.py | 143 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 181 insertions(+), 9 deletions(-) diff --git a/memgpt/llm_api/openai.py b/memgpt/llm_api/openai.py index be68045c9b..4a298e87d4 100644 --- a/memgpt/llm_api/openai.py +++ b/memgpt/llm_api/openai.py @@ -17,6 +17,7 @@ from memgpt.models.chat_completion_request import ChatCompletionRequest from memgpt.models.embedding_response import EmbeddingResponse from memgpt.utils import smart_urljoin, get_utc_time +from memgpt.local_llm.utils import num_tokens_from_messages, num_tokens_from_functions from memgpt.interface import AgentInterface from memgpt.streaming_interface import AgentStreamingInterface @@ -87,24 +88,48 @@ def openai_chat_completions_process_stream( """ assert chat_completion_request.stream == True + # Count the prompt tokens + # TODO move to post-request? 
+ chat_history = [m.model_dump(exclude_none=True) for m in chat_completion_request.messages] + print(chat_history) + + prompt_tokens = num_tokens_from_messages( + messages=chat_history, + model=chat_completion_request.model, + ) + # We also need to add the cost of including the functions list to the input prompt + if chat_completion_request.tools is not None: + assert chat_completion_request.functions is None + prompt_tokens += num_tokens_from_functions( + functions=[t.function.model_dump() for t in chat_completion_request.tools], + model=chat_completion_request.model, + ) + elif chat_completion_request.functions is not None: + assert chat_completion_request.tools is None + prompt_tokens += num_tokens_from_functions( + functions=[f.model_dump() for f in chat_completion_request.functions], + model=chat_completion_request.model, + ) + + TEMP_STREAM_RESPONSE_ID = "temp_id" + TEMP_STREAM_FINISH_REASON = "temp_null" + TEMP_STREAM_TOOL_CALL_ID = "temp_id" chat_completion_response = ChatCompletionResponse( - id="", # NOTE: requires overwrite + id=TEMP_STREAM_RESPONSE_ID, choices=[], created=get_utc_time(), model=chat_completion_request.model, usage=UsageStatistics( completion_tokens=0, - prompt_tokens=0, - total_tokens=0, + prompt_tokens=prompt_tokens, + total_tokens=prompt_tokens, ), ) if stream_inferface: stream_inferface.stream_start() - TEMP_STREAM_FINISH_REASON = "temp_null" - TEMP_STREAM_TOOL_CALL_ID = "temp_id" - + n_chunks = 0 # approx == n_tokens try: for chunk_idx, chat_completion_chunk in enumerate( openai_chat_completions_request(url=url, api_key=api_key, chat_completion_request=chat_completion_request) @@ -182,7 +207,8 @@ def openai_chat_completions_process_stream( chat_completion_response.model = chat_completion_chunk.model # increment chunk counter - chunk_idx += 1 + n_chunks += 1 + except Exception as e: if stream_inferface: stream_inferface.stream_end() @@ -200,11 +226,14 @@ def openai_chat_completions_process_stream( for c in chat_completion_response.choices ] ) + assert chat_completion_response.id != TEMP_STREAM_RESPONSE_ID # compute token usage before returning - # TODO - # print("choices=", chat_completion_response.choices) + # TODO try actually computing the #tokens instead of assuming the chunks is the same + chat_completion_response.usage.completion_tokens = n_chunks + chat_completion_response.usage.total_tokens = prompt_tokens + n_chunks + # printd(chat_completion_response) return chat_completion_response diff --git a/memgpt/local_llm/utils.py b/memgpt/local_llm/utils.py index 8306e91b73..27db5c9e8d 100644 --- a/memgpt/local_llm/utils.py +++ b/memgpt/local_llm/utils.py @@ -1,6 +1,7 @@ import os import requests import tiktoken +from typing import List import memgpt.local_llm.llm_chat_completion_wrappers.airoboros as airoboros import memgpt.local_llm.llm_chat_completion_wrappers.dolphin as dolphin @@ -74,6 +75,148 @@ def count_tokens(s: str, model: str = "gpt-4") -> int: return len(encoding.encode(s)) +def num_tokens_from_functions(functions: List[dict], model: str = "gpt-4"): + """Return the number of tokens used by a list of functions. + + Copied from https://community.openai.com/t/how-to-calculate-the-tokens-when-using-function-call/266573/11 + """ + try: + encoding = tiktoken.encoding_for_model(model) + except KeyError: + print("Warning: model not found. 
Using cl100k_base encoding.") + encoding = tiktoken.get_encoding("cl100k_base") + + num_tokens = 0 + for function in functions: + function_tokens = len(encoding.encode(function["name"])) + function_tokens += len(encoding.encode(function["description"])) + + if "parameters" in function: + parameters = function["parameters"] + if "properties" in parameters: + for propertiesKey in parameters["properties"]: + function_tokens += len(encoding.encode(propertiesKey)) + v = parameters["properties"][propertiesKey] + for field in v: + if field == "type": + function_tokens += 2 + function_tokens += len(encoding.encode(v["type"])) + elif field == "description": + function_tokens += 2 + function_tokens += len(encoding.encode(v["description"])) + elif field == "enum": + function_tokens -= 3 + for o in v["enum"]: + function_tokens += 3 + function_tokens += len(encoding.encode(o)) + else: + print(f"Warning: not supported field {field}") + function_tokens += 11 + + num_tokens += function_tokens + + num_tokens += 12 + return num_tokens + + +def num_tokens_from_tool_calls(tool_calls: List[dict], model: str = "gpt-4"): + """Based on above code (num_tokens_from_functions). + + Example to encode: + [{ + 'id': '8b6707cf-2352-4804-93db-0423f', + 'type': 'function', + 'function': { + 'name': 'send_message', + 'arguments': '{\n "message": "More human than human is our motto."\n}' + } + }] + """ + try: + encoding = tiktoken.encoding_for_model(model) + except KeyError: + print("Warning: model not found. Using cl100k_base encoding.") + encoding = tiktoken.get_encoding("cl100k_base") + + num_tokens = 0 + for tool_call in tool_calls: + function_tokens = len(encoding.encode(tool_call["id"])) + function_tokens += 2 + len(encoding.encode(tool_call["type"])) + function_tokens += 2 + len(encoding.encode(tool_call["function"]["name"])) + function_tokens += 2 + len(encoding.encode(tool_call["function"]["arguments"])) + + num_tokens += function_tokens + + # TODO adjust? + num_tokens += 12 + return num_tokens + + +def num_tokens_from_messages(messages: List[dict], model: str = "gpt-4") -> int: + """Return the number of tokens used by a list of messages. + + From: https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb + + For counting tokens in function calling RESPONSES, see: + https://hmarr.com/blog/counting-openai-tokens/, https://github.com/hmarr/openai-chat-tokens + + For counting tokens in function calling REQUESTS, see: + https://community.openai.com/t/how-to-calculate-the-tokens-when-using-function-call/266573/11 + """ + try: + encoding = tiktoken.encoding_for_model(model) + except KeyError: + print("Warning: model not found. Using cl100k_base encoding.") + encoding = tiktoken.get_encoding("cl100k_base") + if model in { + "gpt-3.5-turbo-0613", + "gpt-3.5-turbo-16k-0613", + "gpt-4-0314", + "gpt-4-32k-0314", + "gpt-4-0613", + "gpt-4-32k-0613", + }: + tokens_per_message = 3 + tokens_per_name = 1 + elif model == "gpt-3.5-turbo-0301": + tokens_per_message = 4 # every message follows <|start|>{role/name}\n{content}<|end|>\n + tokens_per_name = -1 # if there's a name, the role is omitted + elif "gpt-3.5-turbo" in model: + print("Warning: gpt-3.5-turbo may update over time. Returning num tokens assuming gpt-3.5-turbo-0613.") + return num_tokens_from_messages(messages, model="gpt-3.5-turbo-0613") + elif "gpt-4" in model: + print("Warning: gpt-4 may update over time. 
Returning num tokens assuming gpt-4-0613.") + return num_tokens_from_messages(messages, model="gpt-4-0613") + else: + raise NotImplementedError( + f"""num_tokens_from_messages() is not implemented for model {model}. See https://github.com/openai/openai-python/blob/main/chatml.md for information on how messages are converted to tokens.""" + ) + num_tokens = 0 + for message in messages: + num_tokens += tokens_per_message + for key, value in message.items(): + try: + + if isinstance(value, list) and key == "tool_calls": + num_tokens += num_tokens_from_tool_calls(tool_calls=value, model=model) + # special case for tool calling (list) + # num_tokens += len(encoding.encode(value["name"])) + # num_tokens += len(encoding.encode(value["arguments"])) + + else: + num_tokens += len(encoding.encode(value)) + + if key == "name": + num_tokens += tokens_per_name + + except TypeError as e: + print(f"tiktoken encoding failed on: {value}") + raise e + + num_tokens += 3 # every reply is primed with <|start|>assistant<|message|> + return num_tokens + + def get_available_wrappers() -> dict: return { "experimental-wrapper-neural-chat-grammar-noforce": configurable_wrapper.ConfigurableJSONWrapper( From 98024951effdbd1fedf4c8a1c9241095bce7bb7c Mon Sep 17 00:00:00 2001 From: cpacker Date: Tue, 16 Apr 2024 21:35:59 -0700 Subject: [PATCH 6/9] working --- memgpt/cli/cli.py | 4 +- memgpt/llm_api/llm_api_tools.py | 5 +- memgpt/llm_api/openai.py | 16 ++- memgpt/local_llm/utils.py | 8 +- memgpt/main.py | 8 +- memgpt/streaming_interface.py | 241 +++++++++++++++++++++++++++++++- 6 files changed, 262 insertions(+), 20 deletions(-) diff --git a/memgpt/cli/cli.py b/memgpt/cli/cli.py index fe1361c590..2d2e4185f9 100644 --- a/memgpt/cli/cli.py +++ b/memgpt/cli/cli.py @@ -13,7 +13,9 @@ import questionary from memgpt.log import logger -from memgpt.interface import CLIInterface as interface # for printing to terminal + +# from memgpt.interface import CLIInterface as interface # for printing to terminal +from memgpt.streaming_interface import StreamingRefreshCLIInterface as interface # for printing to terminal from memgpt.cli.cli_config import configure import memgpt.presets.presets as presets import memgpt.utils as utils diff --git a/memgpt/llm_api/llm_api_tools.py b/memgpt/llm_api/llm_api_tools.py index c3ab488f2b..1cf72ba853 100644 --- a/memgpt/llm_api/llm_api_tools.py +++ b/memgpt/llm_api/llm_api_tools.py @@ -176,9 +176,10 @@ def create( if stream: data.stream = True - from memgpt.streaming_interface import StreamingCLIInterface + from memgpt.streaming_interface import StreamingCLIInterface, StreamingRefreshCLIInterface - stream_inferface = StreamingCLIInterface() + # stream_inferface = StreamingCLIInterface() + stream_inferface = StreamingRefreshCLIInterface() return openai_chat_completions_process_stream( url=agent_state.llm_config.model_endpoint, # https://api.openai.com/v1 -> https://api.openai.com/v1/chat/completions api_key=credentials.openai_key, diff --git a/memgpt/llm_api/openai.py b/memgpt/llm_api/openai.py index 4a298e87d4..905439de7a 100644 --- a/memgpt/llm_api/openai.py +++ b/memgpt/llm_api/openai.py @@ -19,7 +19,7 @@ from memgpt.utils import smart_urljoin, get_utc_time from memgpt.local_llm.utils import num_tokens_from_messages, num_tokens_from_functions from memgpt.interface import AgentInterface -from memgpt.streaming_interface import AgentStreamingInterface +from memgpt.streaming_interface import AgentChunkStreamingInterface, AgentRefreshStreamingInterface OPENAI_SSE_DONE = "[DONE]" @@ -79,7 +79,7 @@ def 
openai_chat_completions_process_stream( url: str, api_key: str, chat_completion_request: ChatCompletionRequest, - stream_inferface: Optional[AgentStreamingInterface] = None, + stream_inferface: Optional[Union[AgentChunkStreamingInterface, AgentRefreshStreamingInterface]] = None, ) -> ChatCompletionResponse: """Process a streaming completion response, and return a ChatCompletionRequest at the end. @@ -91,7 +91,7 @@ def openai_chat_completions_process_stream( # Count the prompt tokens # TODO move to post-request? chat_history = [m.model_dump(exclude_none=True) for m in chat_completion_request.messages] - print(chat_history) + # print(chat_history) prompt_tokens = num_tokens_from_messages( messages=chat_history, @@ -138,7 +138,12 @@ def openai_chat_completions_process_stream( # print(chat_completion_chunk) if stream_inferface: - stream_inferface.process_chunk(chat_completion_chunk) + if isinstance(stream_inferface, AgentChunkStreamingInterface): + stream_inferface.process_chunk(chat_completion_chunk) + elif isinstance(stream_inferface, AgentRefreshStreamingInterface): + stream_inferface.process_refresh(chat_completion_response) + else: + raise TypeError(stream_inferface) if chunk_idx == 0: # initialize the choice objects which we will increment with the deltas @@ -255,8 +260,7 @@ def openai_chat_completions_request( headers = {"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"} data = chat_completion_request.model_dump(exclude_none=True) - # import json - # print(json.dumps(data, indent=2)) + printd("Request:\n", json.dumps(data, indent=2)) # If functions == None, strip from the payload if "functions" in data and data["functions"] is None: diff --git a/memgpt/local_llm/utils.py b/memgpt/local_llm/utils.py index 27db5c9e8d..15c85c286a 100644 --- a/memgpt/local_llm/utils.py +++ b/memgpt/local_llm/utils.py @@ -135,7 +135,7 @@ def num_tokens_from_tool_calls(tool_calls: List[dict], model: str = "gpt-4"): try: encoding = tiktoken.encoding_for_model(model) except KeyError: - print("Warning: model not found. Using cl100k_base encoding.") + # print("Warning: model not found. Using cl100k_base encoding.") encoding = tiktoken.get_encoding("cl100k_base") num_tokens = 0 @@ -166,7 +166,7 @@ def num_tokens_from_messages(messages: List[dict], model: str = "gpt-4") -> int: try: encoding = tiktoken.encoding_for_model(model) except KeyError: - print("Warning: model not found. Using cl100k_base encoding.") + # print("Warning: model not found. Using cl100k_base encoding.") encoding = tiktoken.get_encoding("cl100k_base") if model in { "gpt-3.5-turbo-0613", @@ -182,10 +182,10 @@ def num_tokens_from_messages(messages: List[dict], model: str = "gpt-4") -> int: tokens_per_message = 4 # every message follows <|start|>{role/name}\n{content}<|end|>\n tokens_per_name = -1 # if there's a name, the role is omitted elif "gpt-3.5-turbo" in model: - print("Warning: gpt-3.5-turbo may update over time. Returning num tokens assuming gpt-3.5-turbo-0613.") + # print("Warning: gpt-3.5-turbo may update over time. Returning num tokens assuming gpt-3.5-turbo-0613.") return num_tokens_from_messages(messages, model="gpt-3.5-turbo-0613") elif "gpt-4" in model: - print("Warning: gpt-4 may update over time. Returning num tokens assuming gpt-4-0613.") + # print("Warning: gpt-4 may update over time. 
Returning num tokens assuming gpt-4-0613.") return num_tokens_from_messages(messages, model="gpt-4-0613") else: raise NotImplementedError( diff --git a/memgpt/main.py b/memgpt/main.py index f7b4375807..c973b2310a 100644 --- a/memgpt/main.py +++ b/memgpt/main.py @@ -13,7 +13,9 @@ console = Console() from memgpt.agent_store.storage import StorageConnector, TableType -from memgpt.interface import CLIInterface as interface # for printing to terminal + +# from memgpt.interface import CLIInterface as interface # for printing to terminal +from memgpt.streaming_interface import StreamingRefreshCLIInterface as interface # for printing to terminal from memgpt.config import MemGPTConfig import memgpt.agent as agent import memgpt.system as system @@ -26,6 +28,8 @@ # import benchmark from memgpt.benchmark.benchmark import bench +interface = interface() + app = typer.Typer(pretty_exceptions_enable=False) app.command(name="run")(run) app.command(name="version")(version) @@ -47,7 +51,7 @@ def clear_line(strip_ui=False): - if strip_ui: + if True or strip_ui: return if os.name == "nt": # for windows console.print("\033[A\033[K", end="") diff --git a/memgpt/streaming_interface.py b/memgpt/streaming_interface.py index d344e9d829..20573ce862 100644 --- a/memgpt/streaming_interface.py +++ b/memgpt/streaming_interface.py @@ -1,17 +1,23 @@ from abc import ABC, abstractmethod import json import re +import sys from typing import List, Optional -from colorama import Fore, Style, init +# from colorama import Fore, Style, init +from rich.console import Console +from rich.live import Live +from rich.markup import escape +from rich.style import Style +from rich.text import Text from memgpt.utils import printd from memgpt.constants import CLI_WARNING_PREFIX, JSON_LOADS_STRICT from memgpt.data_types import Message -from memgpt.models.chat_completion_response import ChatCompletionChunkResponse +from memgpt.models.chat_completion_response import ChatCompletionChunkResponse, ChatCompletionResponse from memgpt.interface import AgentInterface, CLIInterface -init(autoreset=True) +# init(autoreset=True) # DEBUG = True # puts full message outputs in the terminal DEBUG = False # only dumps important messages in the terminal @@ -19,7 +25,7 @@ STRIP_UI = False -class AgentStreamingInterface(ABC): +class AgentChunkStreamingInterface(ABC): """Interfaces handle MemGPT-related events (observer pattern) The 'msg' args provides the scoped message, and the optional Message arg can provide additional metadata. @@ -61,7 +67,7 @@ def stream_end(self): raise NotImplementedError -class StreamingCLIInterface(AgentStreamingInterface): +class StreamingCLIInterface(AgentChunkStreamingInterface): """Version of the CLI interface that attaches to a stream generator and prints along the way. When a chunk is received, we write the delta to the buffer. If the buffer type has changed, @@ -198,3 +204,228 @@ def print_messages_raw(message_sequence: List[Message]): @staticmethod def step_yield(): pass + + +class AgentRefreshStreamingInterface(ABC): + """Same as the ChunkStreamingInterface, but + + The 'msg' args provides the scoped message, and the optional Message arg can provide additional metadata. 
+ """ + + @abstractmethod + def user_message(self, msg: str, msg_obj: Optional[Message] = None): + """MemGPT receives a user message""" + raise NotImplementedError + + @abstractmethod + def internal_monologue(self, msg: str, msg_obj: Optional[Message] = None): + """MemGPT generates some internal monologue""" + raise NotImplementedError + + @abstractmethod + def assistant_message(self, msg: str, msg_obj: Optional[Message] = None): + """MemGPT uses send_message""" + raise NotImplementedError + + @abstractmethod + def function_message(self, msg: str, msg_obj: Optional[Message] = None): + """MemGPT calls a function""" + raise NotImplementedError + + @abstractmethod + def process_refresh(self, response: ChatCompletionResponse): + """Process a streaming chunk from an OpenAI-compatible server""" + raise NotImplementedError + + @abstractmethod + def stream_start(self): + """Any setup required before streaming begins""" + raise NotImplementedError + + @abstractmethod + def stream_end(self): + """Any cleanup required after streaming ends""" + raise NotImplementedError + + +# TODO fix this vile abuse of @staticmethod + +# CLIInterface is static/stateless +# nonstreaming_interface = CLIInterface() +console = Console() +live = Live("", console=console, refresh_per_second=10) +# live.start() # Start the Live display context and keep it running +fancy = True +separate_send_message = True + + +class StreamingRefreshCLIInterface(AgentRefreshStreamingInterface): + """Version of the CLI interface that attaches to a stream generator and refreshes a render of the message at every step. + + We maintain the partial message state in the interface state, and on each + process chunk we: + (1) update the partial message state, + (2) refresh/rewrite the state to the screen. + """ + + nonstreaming_interface = CLIInterface + + # def __init__(self, fancy: bool = True): + # """Initialize the streaming CLI interface state.""" + # # self.console = Console() + + # # Using `Live` with `refresh_per_second` parameter to limit the refresh rate, avoiding excessive updates + # # self.live = Live("", console=self.console, refresh_per_second=10) + # # self.live.start() # Start the Live display context and keep it running + + # # Use italics / emoji? 
+ # self.fancy = fancy + + # def update_output(self, content: str): + # """Update the displayed output with new content.""" + # # We use the `Live` object's update mechanism to refresh content without clearing the console + # if not fancy: + # content = escape(content) + # self.live.update(self.console.render_str(content), refresh=True) + + # def process_refresh(self, response: ChatCompletionResponse): + # """Process the response to rewrite the current output buffer.""" + # if not response.choices: + # return # Early exit if there are no choices + + # choice = response.choices[0] + # inner_thoughts = choice.message.content if choice.message.content else "" + # tool_calls = choice.message.tool_calls if choice.message.tool_calls else [] + + # if self.fancy: + # message_string = f"💭 [italic]{inner_thoughts}[/italic]" if inner_thoughts else "" + # else: + # message_string = "[inner thoughts] " + inner_thoughts if inner_thoughts else "" + + # if tool_calls: + # function_call = tool_calls[0].function + # function_name = function_call.name # Function name, can be an empty string + # function_args = function_call.arguments # Function arguments, can be an empty string + # if message_string: + # message_string += "\n" + # message_string += f"{function_name}({function_args})" + + # self.update_output(message_string) + + # def stream_start(self): + # self.live.start() # Start the Live display context and keep it running + + # def stream_end(self): + # if self.live.is_started: + # self.live.stop() + + @staticmethod + def update_output(content: str): + """Update the displayed output with new content.""" + # We use the `Live` object's update mechanism to refresh content without clearing the console + if not fancy: + content = escape(content) + live.update(console.render_str(content), refresh=True) + + @staticmethod + def process_refresh(response: ChatCompletionResponse): + """Process the response to rewrite the current output buffer.""" + if not response.choices: + return # Early exit if there are no choices + + choice = response.choices[0] + inner_thoughts = choice.message.content if choice.message.content else "" + tool_calls = choice.message.tool_calls if choice.message.tool_calls else [] + + if fancy: + message_string = f"💭 [italic]{inner_thoughts}[/italic]" if inner_thoughts else "" + else: + message_string = "[inner thoughts] " + inner_thoughts if inner_thoughts else "" + + if tool_calls: + function_call = tool_calls[0].function + function_name = function_call.name # Function name, can be an empty string + function_args = function_call.arguments # Function arguments, can be an empty string + if message_string: + message_string += "\n" + # special case here for send_message + if separate_send_message and function_name == "send_message": + try: + message = json.loads(function_args)["message"] + except: + prefix = '{\n "message": "' + if len(function_args) < len(prefix): + message = "..." 
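+                            # mid-stream, function_args is usually an incomplete JSON fragment (json.loads above fails), so the branches below strip the expected prefix and display the partial message text as it streams in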
+ elif function_args.startswith(prefix): + message = function_args[len(prefix) :] + else: + message = function_args + message_string += f"🤖 [bold yellow]{message}[/bold yellow]" + else: + message_string += f"{function_name}({function_args})" + + StreamingRefreshCLIInterface.update_output(message_string) + + @staticmethod + def stream_start(): + print() + live.start() # Start the Live display context and keep it running + + @staticmethod + def stream_end(): + global live + if live.is_started: + live.stop() + print() + live = Live("", console=console, refresh_per_second=10) + + @staticmethod + def important_message(msg: str): + StreamingCLIInterface.nonstreaming_interface.important_message(msg) + + @staticmethod + def warning_message(msg: str): + StreamingCLIInterface.nonstreaming_interface.warning_message(msg) + + @staticmethod + def internal_monologue(msg: str, msg_obj: Optional[Message] = None): + return + # StreamingCLIInterface.nonstreaming_interface.internal_monologue(msg, msg_obj) + + @staticmethod + def assistant_message(msg: str, msg_obj: Optional[Message] = None): + if separate_send_message: + return + StreamingCLIInterface.nonstreaming_interface.assistant_message(msg, msg_obj) + + @staticmethod + def memory_message(msg: str, msg_obj: Optional[Message] = None): + StreamingCLIInterface.nonstreaming_interface.memory_message(msg, msg_obj) + + @staticmethod + def system_message(msg: str, msg_obj: Optional[Message] = None): + StreamingCLIInterface.nonstreaming_interface.system_message(msg, msg_obj) + + @staticmethod + def user_message(msg: str, msg_obj: Optional[Message] = None, raw: bool = False, dump: bool = False, debug: bool = DEBUG): + StreamingCLIInterface.nonstreaming_interface.user_message(msg, msg_obj) + + @staticmethod + def function_message(msg: str, msg_obj: Optional[Message] = None, debug: bool = DEBUG): + StreamingCLIInterface.nonstreaming_interface.function_message(msg, msg_obj) + + @staticmethod + def print_messages(message_sequence: List[Message], dump=False): + StreamingCLIInterface.nonstreaming_interface.print_messages(message_sequence, dump) + + @staticmethod + def print_messages_simple(message_sequence: List[Message]): + StreamingCLIInterface.nonstreaming_interface.print_messages_simple(message_sequence) + + @staticmethod + def print_messages_raw(message_sequence: List[Message]): + StreamingCLIInterface.nonstreaming_interface.print_messages_raw(message_sequence) + + @staticmethod + def step_yield(): + pass From 7c0d24638f4384c4b08431fc9de86d3f5807e9f8 Mon Sep 17 00:00:00 2001 From: cpacker Date: Tue, 16 Apr 2024 22:46:29 -0700 Subject: [PATCH 7/9] pass stream as flag via CLI, fix bug with non-streaming --- memgpt/agent.py | 7 ++ memgpt/cli/cli.py | 6 +- memgpt/llm_api/llm_api_tools.py | 25 +++-- memgpt/llm_api/openai.py | 157 ++++++++++++++++++++------------ memgpt/main.py | 20 ++-- memgpt/streaming_interface.py | 17 +++- 6 files changed, 151 insertions(+), 81 deletions(-) diff --git a/memgpt/agent.py b/memgpt/agent.py index 26a903bf0a..d24181a74e 100644 --- a/memgpt/agent.py +++ b/memgpt/agent.py @@ -403,6 +403,7 @@ def _get_ai_reply( message_sequence: List[Message], function_call: str = "auto", first_message: bool = False, # hint + stream: bool = False, # TODO move to config? 
) -> chat_completion_response.ChatCompletionResponse: """Get response from LLM API""" try: @@ -414,6 +415,9 @@ def _get_ai_reply( function_call=function_call, # hint first_message=first_message, + # streaming + stream=stream, + stream_inferface=self.interface, ) # special case for 'length' if response.choices[0].finish_reason == "length": @@ -628,6 +632,7 @@ def step( skip_verify: bool = False, return_dicts: bool = True, # if True, return dicts, if False, return Message objects recreate_message_timestamp: bool = True, # if True, when input is a Message type, recreated the 'created_at' field + stream: bool = False, # TODO move to config? ) -> Tuple[List[Union[dict, Message]], bool, bool, bool]: """Top-level event message handler for the MemGPT agent""" @@ -710,6 +715,7 @@ def validate_json(user_message_text: str, raise_on_error: bool) -> str: response = self._get_ai_reply( message_sequence=input_message_sequence, first_message=True, # passed through to the prompt formatter + stream=stream, ) if verify_first_message_correctness(response, require_monologue=self.first_message_verify_mono): break @@ -721,6 +727,7 @@ def validate_json(user_message_text: str, raise_on_error: bool) -> str: else: response = self._get_ai_reply( message_sequence=input_message_sequence, + stream=stream, ) # Step 2: check if LLM wanted to call a function diff --git a/memgpt/cli/cli.py b/memgpt/cli/cli.py index 2d2e4185f9..02260050b0 100644 --- a/memgpt/cli/cli.py +++ b/memgpt/cli/cli.py @@ -447,6 +447,8 @@ def run( debug: Annotated[bool, typer.Option(help="Use --debug to enable debugging output")] = False, no_verify: Annotated[bool, typer.Option(help="Bypass message verification")] = False, yes: Annotated[bool, typer.Option("-y", help="Skip confirmation prompt and use defaults")] = False, + # streaming + stream: Annotated[bool, typer.Option(help="Enables message streaming in the CLI (if the backend supports it)")] = False, ): """Start chatting with an MemGPT agent @@ -712,7 +714,9 @@ def run( from memgpt.main import run_agent_loop print() # extra space - run_agent_loop(memgpt_agent, config, first, ms, no_verify) # TODO: add back no_verify + run_agent_loop( + memgpt_agent=memgpt_agent, config=config, first=first, ms=ms, no_verify=no_verify, stream=stream + ) # TODO: add back no_verify def delete_agent( diff --git a/memgpt/llm_api/llm_api_tools.py b/memgpt/llm_api/llm_api_tools.py index 1cf72ba853..fbbaa2f6b2 100644 --- a/memgpt/llm_api/llm_api_tools.py +++ b/memgpt/llm_api/llm_api_tools.py @@ -3,13 +3,14 @@ import requests import os import time -from typing import List +from typing import List, Optional, Union from memgpt.credentials import MemGPTCredentials from memgpt.local_llm.chat_completion_proxy import get_chat_completion from memgpt.constants import CLI_WARNING_PREFIX from memgpt.models.chat_completion_response import ChatCompletionResponse from memgpt.models.chat_completion_request import ChatCompletionRequest, Tool, cast_message_to_subtype +from memgpt.streaming_interface import AgentChunkStreamingInterface, AgentRefreshStreamingInterface from memgpt.data_types import AgentState, Message @@ -126,18 +127,17 @@ def wrapper(*args, **kwargs): def create( agent_state: AgentState, messages: List[Message], - functions=None, - functions_python=None, - function_call="auto", + functions: list = None, + functions_python: list = None, + function_call: str = "auto", # hint - first_message=False, + first_message: bool = False, # use tool naming? 
# if false, will use deprecated 'functions' style - use_tool_naming=True, + use_tool_naming: bool = True, # streaming? - # stream=False, - stream=True, - stream_inferface=None, + stream: bool = False, + stream_inferface: Optional[Union[AgentRefreshStreamingInterface, AgentChunkStreamingInterface]] = None, ) -> ChatCompletionResponse: """Return response to chat completion with backoff""" from memgpt.utils import printd @@ -176,10 +176,9 @@ def create( if stream: data.stream = True - from memgpt.streaming_interface import StreamingCLIInterface, StreamingRefreshCLIInterface - - # stream_inferface = StreamingCLIInterface() - stream_inferface = StreamingRefreshCLIInterface() + assert isinstance(stream_inferface, AgentChunkStreamingInterface) or isinstance( + stream_inferface, AgentRefreshStreamingInterface + ), type(stream_inferface) return openai_chat_completions_process_stream( url=agent_state.llm_config.model_endpoint, # https://api.openai.com/v1 -> https://api.openai.com/v1/chat/completions api_key=credentials.openai_key, diff --git a/memgpt/llm_api/openai.py b/memgpt/llm_api/openai.py index 905439de7a..8b4d347471 100644 --- a/memgpt/llm_api/openai.py +++ b/memgpt/llm_api/openai.py @@ -132,7 +132,7 @@ def openai_chat_completions_process_stream( n_chunks = 0 # approx == n_tokens try: for chunk_idx, chat_completion_chunk in enumerate( - openai_chat_completions_request(url=url, api_key=api_key, chat_completion_request=chat_completion_request) + openai_chat_completions_request_stream(url=url, api_key=api_key, chat_completion_request=chat_completion_request) ): assert isinstance(chat_completion_chunk, ChatCompletionChunkResponse), type(chat_completion_chunk) # print(chat_completion_chunk) @@ -242,11 +242,98 @@ def openai_chat_completions_process_stream( return chat_completion_response +def _sse_post(url: str, data: dict, headers: dict) -> Generator[ChatCompletionChunkResponse, None, None]: + + with httpx.Client() as client: + with connect_sse(client, method="POST", url=url, json=data, headers=headers) as event_source: + try: + for sse in event_source.iter_sse(): + # printd(sse.event, sse.data, sse.id, sse.retry) + if sse.data == OPENAI_SSE_DONE: + # print("finished") + break + else: + chunk_data = json.loads(sse.data) + # print("chunk_data::", chunk_data) + chunk_object = ChatCompletionChunkResponse(**chunk_data) + # print("chunk_object::", chunk_object) + # id=chunk_data["id"], + # choices=[ChunkChoice], + # model=chunk_data["model"], + # system_fingerprint=chunk_data["system_fingerprint"] + # ) + yield chunk_object + + except SSEError as e: + if "application/json" in str(e): # Check if the error is because of JSON response + response = client.post(url=url, json=data, headers=headers) # Make the request again to get the JSON response + if response.headers["Content-Type"].startswith("application/json"): + error_details = response.json() # Parse the JSON to get the error message + print("Error:", error_details) + print("Reqeust:", vars(response.request)) + else: + print("Failed to retrieve JSON error message.") + else: + print("SSEError not related to 'application/json' content type.") + + # Optionally re-raise the exception if you need to propagate it + raise e + + except Exception as e: + if event_source.response.request is not None: + print("HTTP Request:", vars(event_source.response.request)) + if event_source.response is not None: + print("HTTP Status:", event_source.response.status_code) + print("HTTP Headers:", event_source.response.headers) + # print("HTTP Body:", 
event_source.response.text) + print("Exception message:", str(e)) + raise e + + +def openai_chat_completions_request_stream( + url: str, + api_key: str, + chat_completion_request: ChatCompletionRequest, +) -> Generator[ChatCompletionChunkResponse, None, None]: + from memgpt.utils import printd + + url = smart_urljoin(url, "chat/completions") + headers = {"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"} + data = chat_completion_request.model_dump(exclude_none=True) + + printd("Request:\n", json.dumps(data, indent=2)) + + # If functions == None, strip from the payload + if "functions" in data and data["functions"] is None: + data.pop("functions") + data.pop("function_call", None) # extra safe, should exist always (default="auto") + + if "tools" in data and data["tools"] is None: + data.pop("tools") + data.pop("tool_choice", None) # extra safe, should exist always (default="auto") + + printd(f"Sending request to {url}") + try: + return _sse_post(url=url, data=data, headers=headers) + except requests.exceptions.HTTPError as http_err: + # Handle HTTP errors (e.g., response 4XX, 5XX) + printd(f"Got HTTPError, exception={http_err}, payload={data}") + raise http_err + except requests.exceptions.RequestException as req_err: + # Handle other requests-related errors (e.g., connection error) + printd(f"Got RequestException, exception={req_err}") + raise req_err + except Exception as e: + # Handle other potential errors + printd(f"Got unknown Exception, exception={e}") + raise e + + def openai_chat_completions_request( url: str, api_key: str, chat_completion_request: ChatCompletionRequest, -) -> Union[ChatCompletionResponse, Generator[ChatCompletionChunkResponse, None, None]]: +) -> ChatCompletionResponse: """Send a ChatCompletion request to an OpenAI-compatible server If request.stream == True, will yield ChatCompletionChunkResponses @@ -273,63 +360,15 @@ def openai_chat_completions_request( printd(f"Sending request to {url}") try: - if data["stream"] == True: - - with httpx.Client() as client: - with connect_sse(client, method="POST", url=url, json=data, headers=headers) as event_source: - try: - for sse in event_source.iter_sse(): - # printd(sse.event, sse.data, sse.id, sse.retry) - if sse.data == OPENAI_SSE_DONE: - # print("finished") - break - else: - chunk_data = json.loads(sse.data) - # print("chunk_data::", chunk_data) - chunk_object = ChatCompletionChunkResponse(**chunk_data) - # print("chunk_object::", chunk_object) - # id=chunk_data["id"], - # choices=[ChunkChoice], - # model=chunk_data["model"], - # system_fingerprint=chunk_data["system_fingerprint"] - # ) - yield chunk_object - - except SSEError as e: - if "application/json" in str(e): # Check if the error is because of JSON response - response = client.post(url=url, json=data, headers=headers) # Make the request again to get the JSON response - if response.headers["Content-Type"].startswith("application/json"): - error_details = response.json() # Parse the JSON to get the error message - print("Error:", error_details) - print("Reqeust:", vars(response.request)) - else: - print("Failed to retrieve JSON error message.") - else: - print("SSEError not related to 'application/json' content type.") - - # Optionally re-raise the exception if you need to propagate it - raise e - - except Exception as e: - if event_source.response.request is not None: - print("HTTP Request:", vars(event_source.response.request)) - if event_source.response is not None: - print("HTTP Status:", event_source.response.status_code) - print("HTTP 
Headers:", event_source.response.headers) - # print("HTTP Body:", event_source.response.text) - print("Exception message:", str(e)) - raise e - - else: - response = requests.post(url, headers=headers, json=data) - printd(f"response = {response}") - response.raise_for_status() # Raises HTTPError for 4XX/5XX status - - response = response.json() # convert to dict from string - printd(f"response.json = {response}") - - response = ChatCompletionResponse(**response) # convert to 'dot-dict' style which is the openai python client default - return response + response = requests.post(url, headers=headers, json=data) + printd(f"response = {response}") + response.raise_for_status() # Raises HTTPError for 4XX/5XX status + + response = response.json() # convert to dict from string + printd(f"response.json = {response}") + + response = ChatCompletionResponse(**response) # convert to 'dot-dict' style which is the openai python client default + return response except requests.exceptions.HTTPError as http_err: # Handle HTTP errors (e.g., response 4XX, 5XX) printd(f"Got HTTPError, exception={http_err}, payload={data}") diff --git a/memgpt/main.py b/memgpt/main.py index c973b2310a..1f629cfa42 100644 --- a/memgpt/main.py +++ b/memgpt/main.py @@ -60,7 +60,10 @@ def clear_line(strip_ui=False): sys.stdout.flush() -def run_agent_loop(memgpt_agent, config: MemGPTConfig, first, ms: MetadataStore, no_verify=False, cfg=None, strip_ui=False): +def run_agent_loop(memgpt_agent, config: MemGPTConfig, first, ms: MetadataStore, no_verify=False, cfg=None, strip_ui=False, stream=False): + # TODO remove + interface.toggle_streaming(on=stream) + counter = 0 user_input = None skip_next_user_input = False @@ -349,7 +352,10 @@ def run_agent_loop(memgpt_agent, config: MemGPTConfig, first, ms: MetadataStore, def process_agent_step(user_message, no_verify): new_messages, heartbeat_request, function_failed, token_warning, tokens_accumulated = memgpt_agent.step( - user_message, first_message=False, skip_verify=no_verify + user_message, + first_message=False, + skip_verify=no_verify, + stream=stream, ) skip_next_user_input = False @@ -371,10 +377,12 @@ def process_agent_step(user_message, no_verify): new_messages, user_message, skip_next_user_input = process_agent_step(user_message, no_verify) break else: - # with console.status("[bold cyan]Thinking...") as status: - # new_messages, user_message, skip_next_user_input = process_agent_step(user_message, no_verify) - # break - new_messages, user_message, skip_next_user_input = process_agent_step(user_message, no_verify) + if stream: + # Don't display the "Thinking..." 
if streaming + new_messages, user_message, skip_next_user_input = process_agent_step(user_message, no_verify) + else: + with console.status("[bold cyan]Thinking...") as status: + new_messages, user_message, skip_next_user_input = process_agent_step(user_message, no_verify) break except KeyboardInterrupt: print("User interrupt occurred.") diff --git a/memgpt/streaming_interface.py b/memgpt/streaming_interface.py index 20573ce862..464f37e234 100644 --- a/memgpt/streaming_interface.py +++ b/memgpt/streaming_interface.py @@ -257,6 +257,7 @@ def stream_end(self): # live.start() # Start the Live display context and keep it running fancy = True separate_send_message = True +disable_inner_mono_call = True class StreamingRefreshCLIInterface(AgentRefreshStreamingInterface): @@ -319,6 +320,17 @@ class StreamingRefreshCLIInterface(AgentRefreshStreamingInterface): # if self.live.is_started: # self.live.stop() + @staticmethod + def toggle_streaming(on: bool): + global separate_send_message + global disable_inner_mono_call + if on: + separate_send_message = True + disable_inner_mono_call = True + else: + separate_send_message = False + disable_inner_mono_call = False + @staticmethod def update_output(content: str): """Update the displayed output with new content.""" @@ -389,8 +401,9 @@ def warning_message(msg: str): @staticmethod def internal_monologue(msg: str, msg_obj: Optional[Message] = None): - return - # StreamingCLIInterface.nonstreaming_interface.internal_monologue(msg, msg_obj) + if disable_inner_mono_call: + return + StreamingCLIInterface.nonstreaming_interface.internal_monologue(msg, msg_obj) @staticmethod def assistant_message(msg: str, msg_obj: Optional[Message] = None): From 8d98e4cbc452669121f32543ee6ae020fea5e91c Mon Sep 17 00:00:00 2001 From: cpacker Date: Wed, 17 Apr 2024 19:04:21 -0700 Subject: [PATCH 8/9] made streaminginterface stateful to remove global vars --- memgpt/main.py | 41 +++++++---- memgpt/streaming_interface.py | 134 +++++++++++----------------------- 2 files changed, 69 insertions(+), 106 deletions(-) diff --git a/memgpt/main.py b/memgpt/main.py index 1f629cfa42..2a3a740b35 100644 --- a/memgpt/main.py +++ b/memgpt/main.py @@ -10,12 +10,10 @@ from rich.console import Console from memgpt.constants import FUNC_FAILED_HEARTBEAT_MESSAGE, JSON_ENSURE_ASCII, JSON_LOADS_STRICT, REQ_HEARTBEAT_MESSAGE -console = Console() - from memgpt.agent_store.storage import StorageConnector, TableType # from memgpt.interface import CLIInterface as interface # for printing to terminal -from memgpt.streaming_interface import StreamingRefreshCLIInterface as interface # for printing to terminal +from memgpt.streaming_interface import AgentRefreshStreamingInterface from memgpt.config import MemGPTConfig import memgpt.agent as agent import memgpt.system as system @@ -28,7 +26,7 @@ # import benchmark from memgpt.benchmark.benchmark import bench -interface = interface() +# interface = interface() app = typer.Typer(pretty_exceptions_enable=False) app.command(name="run")(run) @@ -50,8 +48,8 @@ app.command(name="delete-agent")(delete_agent) -def clear_line(strip_ui=False): - if True or strip_ui: +def clear_line(console, strip_ui=False): + if strip_ui: return if os.name == "nt": # for windows console.print("\033[A\033[K", end="") @@ -60,9 +58,18 @@ def clear_line(strip_ui=False): sys.stdout.flush() -def run_agent_loop(memgpt_agent, config: MemGPTConfig, first, ms: MetadataStore, no_verify=False, cfg=None, strip_ui=False, stream=False): - # TODO remove - interface.toggle_streaming(on=stream) +def 
run_agent_loop( + memgpt_agent: agent.Agent, config: MemGPTConfig, first, ms: MetadataStore, no_verify=False, cfg=None, strip_ui=False, stream=False +): + if isinstance(memgpt_agent.interface, AgentRefreshStreamingInterface): + # memgpt_agent.interface.toggle_streaming(on=stream) + if not stream: + memgpt_agent.interface = memgpt_agent.interface.nonstreaming_interface + + if hasattr(memgpt_agent.interface, "console"): + console = memgpt_agent.interface.console + else: + console = Console() counter = 0 user_input = None @@ -71,8 +78,8 @@ def run_agent_loop(memgpt_agent, config: MemGPTConfig, first, ms: MetadataStore, USER_GOES_FIRST = first if not USER_GOES_FIRST: - console.input("[bold cyan]Hit enter to begin (will request first MemGPT message)[/bold cyan]") - clear_line(strip_ui) + console.input("[bold cyan]Hit enter to begin (will request first MemGPT message)[/bold cyan]\n") + clear_line(console, strip_ui=strip_ui) print() multiline_input = False @@ -80,12 +87,14 @@ def run_agent_loop(memgpt_agent, config: MemGPTConfig, first, ms: MetadataStore, while True: if not skip_next_user_input and (counter > 0 or USER_GOES_FIRST): # Ask for user input + print() user_input = questionary.text( "Enter your message:", multiline=multiline_input, qmark=">", ).ask() - clear_line(strip_ui) + clear_line(console, strip_ui=strip_ui) + print() # Gracefully exit on Ctrl-C/D if user_input is None: @@ -163,13 +172,13 @@ def run_agent_loop(memgpt_agent, config: MemGPTConfig, first, ms: MetadataStore, command = user_input.strip().split() amount = int(command[1]) if len(command) > 1 and command[1].isdigit() else 0 if amount == 0: - interface.print_messages(memgpt_agent._messages, dump=True) + memgpt_agent.interface.print_messages(memgpt_agent._messages, dump=True) else: - interface.print_messages(memgpt_agent._messages[-min(amount, len(memgpt_agent.messages)) :], dump=True) + memgpt_agent.interface.print_messages(memgpt_agent._messages[-min(amount, len(memgpt_agent.messages)) :], dump=True) continue elif user_input.lower() == "/dumpraw": - interface.print_messages_raw(memgpt_agent._messages) + memgpt_agent.interface.print_messages_raw(memgpt_agent._messages) continue elif user_input.lower() == "/memory": @@ -319,7 +328,7 @@ def run_agent_loop(memgpt_agent, config: MemGPTConfig, first, ms: MetadataStore, # No skip options elif user_input.lower() == "/wipe": - memgpt_agent = agent.Agent(interface) + memgpt_agent = agent.Agent(memgpt_agent.interface) user_message = None elif user_input.lower() == "/heartbeat": diff --git a/memgpt/streaming_interface.py b/memgpt/streaming_interface.py index 464f37e234..928810da8a 100644 --- a/memgpt/streaming_interface.py +++ b/memgpt/streaming_interface.py @@ -247,17 +247,10 @@ def stream_end(self): """Any cleanup required after streaming ends""" raise NotImplementedError - -# TODO fix this vile abuse of @staticmethod - -# CLIInterface is static/stateless -# nonstreaming_interface = CLIInterface() -console = Console() -live = Live("", console=console, refresh_per_second=10) -# live.start() # Start the Live display context and keep it running -fancy = True -separate_send_message = True -disable_inner_mono_call = True + @abstractmethod + def toggle_streaming(self, on: bool): + """Toggle streaming on/off (off = regular CLI interface)""" + raise NotImplementedError class StreamingRefreshCLIInterface(AgentRefreshStreamingInterface): @@ -271,85 +264,48 @@ class StreamingRefreshCLIInterface(AgentRefreshStreamingInterface): nonstreaming_interface = CLIInterface - # def __init__(self, fancy: 
bool = True): - # """Initialize the streaming CLI interface state.""" - # # self.console = Console() - - # # Using `Live` with `refresh_per_second` parameter to limit the refresh rate, avoiding excessive updates - # # self.live = Live("", console=self.console, refresh_per_second=10) - # # self.live.start() # Start the Live display context and keep it running - - # # Use italics / emoji? - # self.fancy = fancy - - # def update_output(self, content: str): - # """Update the displayed output with new content.""" - # # We use the `Live` object's update mechanism to refresh content without clearing the console - # if not fancy: - # content = escape(content) - # self.live.update(self.console.render_str(content), refresh=True) + def __init__(self, fancy: bool = True, separate_send_message: bool = True, disable_inner_mono_call: bool = True): + """Initialize the streaming CLI interface state.""" + self.console = Console() - # def process_refresh(self, response: ChatCompletionResponse): - # """Process the response to rewrite the current output buffer.""" - # if not response.choices: - # return # Early exit if there are no choices + # Using `Live` with `refresh_per_second` parameter to limit the refresh rate, avoiding excessive updates + self.live = Live("", console=self.console, refresh_per_second=10) + # self.live.start() # Start the Live display context and keep it running - # choice = response.choices[0] - # inner_thoughts = choice.message.content if choice.message.content else "" - # tool_calls = choice.message.tool_calls if choice.message.tool_calls else [] + # Use italics / emoji? + self.fancy = fancy - # if self.fancy: - # message_string = f"💭 [italic]{inner_thoughts}[/italic]" if inner_thoughts else "" - # else: - # message_string = "[inner thoughts] " + inner_thoughts if inner_thoughts else "" + self.streaming = True + self.separate_send_message = separate_send_message + self.disable_inner_mono_call = disable_inner_mono_call - # if tool_calls: - # function_call = tool_calls[0].function - # function_name = function_call.name # Function name, can be an empty string - # function_args = function_call.arguments # Function arguments, can be an empty string - # if message_string: - # message_string += "\n" - # message_string += f"{function_name}({function_args})" - - # self.update_output(message_string) - - # def stream_start(self): - # self.live.start() # Start the Live display context and keep it running - - # def stream_end(self): - # if self.live.is_started: - # self.live.stop() - - @staticmethod - def toggle_streaming(on: bool): - global separate_send_message - global disable_inner_mono_call + def toggle_streaming(self, on: bool): + self.streaming = on if on: - separate_send_message = True - disable_inner_mono_call = True + self.separate_send_message = True + self.disable_inner_mono_call = True else: - separate_send_message = False - disable_inner_mono_call = False + self.separate_send_message = False + self.disable_inner_mono_call = False - @staticmethod - def update_output(content: str): + def update_output(self, content: str): """Update the displayed output with new content.""" # We use the `Live` object's update mechanism to refresh content without clearing the console - if not fancy: + if not self.fancy: content = escape(content) - live.update(console.render_str(content), refresh=True) + self.live.update(self.console.render_str(content), refresh=True) - @staticmethod - def process_refresh(response: ChatCompletionResponse): + def process_refresh(self, response: ChatCompletionResponse): 
"""Process the response to rewrite the current output buffer.""" if not response.choices: + self.update_output("💭 [italic]...[/italic]") return # Early exit if there are no choices choice = response.choices[0] inner_thoughts = choice.message.content if choice.message.content else "" tool_calls = choice.message.tool_calls if choice.message.tool_calls else [] - if fancy: + if self.fancy: message_string = f"💭 [italic]{inner_thoughts}[/italic]" if inner_thoughts else "" else: message_string = "[inner thoughts] " + inner_thoughts if inner_thoughts else "" @@ -361,7 +317,7 @@ def process_refresh(response: ChatCompletionResponse): if message_string: message_string += "\n" # special case here for send_message - if separate_send_message and function_name == "send_message": + if self.separate_send_message and function_name == "send_message": try: message = json.loads(function_args)["message"] except: @@ -376,20 +332,20 @@ def process_refresh(response: ChatCompletionResponse): else: message_string += f"{function_name}({function_args})" - StreamingRefreshCLIInterface.update_output(message_string) - - @staticmethod - def stream_start(): - print() - live.start() # Start the Live display context and keep it running + self.update_output(message_string) - @staticmethod - def stream_end(): - global live - if live.is_started: - live.stop() + def stream_start(self): + if self.streaming: print() - live = Live("", console=console, refresh_per_second=10) + self.live.start() # Start the Live display context and keep it running + self.update_output("💭 [italic]...[/italic]") + + def stream_end(self): + if self.streaming: + if self.live.is_started: + self.live.stop() + print() + self.live = Live("", console=self.console, refresh_per_second=10) @staticmethod def important_message(msg: str): @@ -399,15 +355,13 @@ def important_message(msg: str): def warning_message(msg: str): StreamingCLIInterface.nonstreaming_interface.warning_message(msg) - @staticmethod - def internal_monologue(msg: str, msg_obj: Optional[Message] = None): - if disable_inner_mono_call: + def internal_monologue(self, msg: str, msg_obj: Optional[Message] = None): + if self.disable_inner_mono_call: return StreamingCLIInterface.nonstreaming_interface.internal_monologue(msg, msg_obj) - @staticmethod - def assistant_message(msg: str, msg_obj: Optional[Message] = None): - if separate_send_message: + def assistant_message(self, msg: str, msg_obj: Optional[Message] = None): + if self.separate_send_message: return StreamingCLIInterface.nonstreaming_interface.assistant_message(msg, msg_obj) From d18951960d057b48bfecedac07ea2fdea0093bef Mon Sep 17 00:00:00 2001 From: cpacker Date: Wed, 17 Apr 2024 19:37:41 -0700 Subject: [PATCH 9/9] remove extra newlines on streaming, make config location print a printd --- memgpt/config.py | 3 ++- memgpt/main.py | 6 ++++-- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/memgpt/config.py b/memgpt/config.py index 1aff760d08..1df21fff42 100644 --- a/memgpt/config.py +++ b/memgpt/config.py @@ -90,6 +90,7 @@ def generate_uuid() -> str: def load(cls) -> "MemGPTConfig": # avoid circular import from memgpt.migrate import config_is_compatible, VERSION_CUTOFF + from memgpt.utils import printd if not config_is_compatible(allow_empty=True): error_message = " ".join( @@ -110,7 +111,7 @@ def load(cls) -> "MemGPTConfig": # insure all configuration directories exist cls.create_config_dir() - print(f"Loading config from {config_path}") + printd(f"Loading config from {config_path}") if os.path.exists(config_path): # read existing 
config config.read(config_path) diff --git a/memgpt/main.py b/memgpt/main.py index 2a3a740b35..c6329aa5b3 100644 --- a/memgpt/main.py +++ b/memgpt/main.py @@ -87,14 +87,16 @@ def run_agent_loop( while True: if not skip_next_user_input and (counter > 0 or USER_GOES_FIRST): # Ask for user input - print() + if not stream: + print() user_input = questionary.text( "Enter your message:", multiline=multiline_input, qmark=">", ).ask() clear_line(console, strip_ui=strip_ui) - print() + if not stream: + print() # Gracefully exit on Ctrl-C/D if user_input is None:
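For orientation, a minimal sketch (not part of the patch) of driving the new SSE streaming helpers added in this series directly; the model name, API key, and plain-dict messages below are placeholder assumptions, not values taken from the patch:

from memgpt.llm_api.openai import openai_chat_completions_request_stream
from memgpt.models.chat_completion_request import ChatCompletionRequest

# stream=True is what the SSE path expects (openai_chat_completions_process_stream asserts it)
request = ChatCompletionRequest(
    model="gpt-4",                                    # placeholder model name
    messages=[{"role": "user", "content": "hello"}],  # assumes plain dicts are accepted here
    stream=True,
)

# Yields one ChatCompletionChunkResponse per SSE event. The CLI path
# (memgpt run --stream) instead goes through Agent.step(stream=True) ->
# create(stream=True, stream_inferface=...) -> openai_chat_completions_process_stream,
# which consumes this same generator, rebuilds a full ChatCompletionResponse,
# and refreshes the streaming interface along the way.
for chunk in openai_chat_completions_request_stream(
    url="https://api.openai.com/v1",
    api_key="sk-placeholder",
    chat_completion_request=request,
):
    print(chunk)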