feat: add streaming support for OpenAI-compatible endpoints #1262

Merged: 9 commits, Apr 18, 2024
7 changes: 7 additions & 0 deletions memgpt/agent.py
@@ -114,8 +114,8 @@
recall_memory: Optional[RecallMemory] = None,
include_char_count: bool = True,
):
full_system_message = "\n".join(
[
system,
"\n",
f"### Memory [last modified: {memory_edit_timestamp.strip()}]",

Check failure on line 117 in memgpt/agent.py (GitHub Actions / Pyright types check (3.11)): No overloads for "join" match the provided arguments (reportCallIssue)
Check failure on line 118 in memgpt/agent.py (GitHub Actions / Pyright types check (3.11)): Argument of type "list[str | Unknown | None]" cannot be assigned to parameter "iterable" of type "Iterable[str]" in function "join"; "None" is incompatible with "str" (reportArgumentType)
@@ -403,6 +403,7 @@
message_sequence: List[Message],
function_call: str = "auto",
first_message: bool = False, # hint
stream: bool = False, # TODO move to config?
) -> chat_completion_response.ChatCompletionResponse:
"""Get response from LLM API"""
try:
@@ -414,6 +415,9 @@
function_call=function_call,
# hint
first_message=first_message,
# streaming
stream=stream,
stream_inferface=self.interface,
)
# special case for 'length'
if response.choices[0].finish_reason == "length":
@@ -628,6 +632,7 @@
skip_verify: bool = False,
return_dicts: bool = True, # if True, return dicts, if False, return Message objects
recreate_message_timestamp: bool = True, # if True, when input is a Message type, recreate the 'created_at' field
stream: bool = False, # TODO move to config?
) -> Tuple[List[Union[dict, Message]], bool, bool, bool]:
"""Top-level event message handler for the MemGPT agent"""

@@ -710,6 +715,7 @@
response = self._get_ai_reply(
message_sequence=input_message_sequence,
first_message=True, # passed through to the prompt formatter
stream=stream,
)
if verify_first_message_correctness(response, require_monologue=self.first_message_verify_mono):
break
@@ -721,6 +727,7 @@
else:
response = self._get_ai_reply(
message_sequence=input_message_sequence,
stream=stream,
)

# Step 2: check if LLM wanted to call a function
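Taken together, the agent.py changes just thread a stream flag from the top-level step() handler down to the API call in create(). A simplified, self-contained sketch of that plumbing pattern (illustrative names and bodies, not MemGPT's real signatures):

```python
# Simplified sketch of the flag plumbing added above: `stream` defaults
# to False and is threaded step() -> _get_ai_reply() -> create().
from typing import Callable

def create(stream: bool = False, on_token: Callable[[str], None] = print) -> str:
    if stream:
        for token in ("Hel", "lo", "!"):
            on_token(token)  # emit partial output as it arrives
    return "Hello!"  # the complete response is returned either way

def _get_ai_reply(stream: bool = False) -> str:
    return create(stream=stream)

def step(user_message: str, stream: bool = False) -> str:
    return _get_ai_reply(stream=stream)

print(step("hi", stream=True))
```

Defaulting stream=False at every layer keeps the existing blocking behavior unless a caller opts in, which matches the TODO about eventually moving the flag into config.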
10 changes: 8 additions & 2 deletions memgpt/cli/cli.py
@@ -13,7 +13,9 @@
import questionary

from memgpt.log import logger
from memgpt.interface import CLIInterface as interface # for printing to terminal

# from memgpt.interface import CLIInterface as interface # for printing to terminal
from memgpt.streaming_interface import StreamingRefreshCLIInterface as interface # for printing to terminal
from memgpt.cli.cli_config import configure
import memgpt.presets.presets as presets
import memgpt.utils as utils
@@ -445,6 +447,8 @@ def run(
debug: Annotated[bool, typer.Option(help="Use --debug to enable debugging output")] = False,
no_verify: Annotated[bool, typer.Option(help="Bypass message verification")] = False,
yes: Annotated[bool, typer.Option("-y", help="Skip confirmation prompt and use defaults")] = False,
# streaming
stream: Annotated[bool, typer.Option(help="Enables message streaming in the CLI (if the backend supports it)")] = False,
):
"""Start chatting with an MemGPT agent

@@ -710,7 +714,9 @@ def run(
from memgpt.main import run_agent_loop

print() # extra space
run_agent_loop(memgpt_agent, config, first, ms, no_verify) # TODO: add back no_verify
run_agent_loop(
memgpt_agent=memgpt_agent, config=config, first=first, ms=ms, no_verify=no_verify, stream=stream
) # TODO: add back no_verify


def delete_agent(
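The new --stream option follows the same Annotated / typer.Option pattern as the surrounding flags. A minimal standalone example of that pattern (a toy app, not the real memgpt CLI):

```python
# Toy typer app illustrating the Annotated option pattern used by `run`
# above; standalone example, not the actual memgpt entry point.
import typer
from typing_extensions import Annotated

app = typer.Typer()

@app.command()
def run(
    stream: Annotated[bool, typer.Option(help="Enable message streaming")] = False,
):
    typer.echo(f"streaming: {stream}")

if __name__ == "__main__":
    app()
```

With a boolean option like this, typer exposes the flag as --stream (and --no-stream), so the feature is invoked as memgpt run --stream.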
3 changes: 2 additions & 1 deletion memgpt/config.py
@@ -90,6 +90,7 @@ def generate_uuid() -> str:
def load(cls) -> "MemGPTConfig":
# avoid circular import
from memgpt.migrate import config_is_compatible, VERSION_CUTOFF
from memgpt.utils import printd

if not config_is_compatible(allow_empty=True):
error_message = " ".join(
@@ -110,7 +111,7 @@ def load(cls) -> "MemGPTConfig":

# ensure all configuration directories exist
cls.create_config_dir()
print(f"Loading config from {config_path}")
printd(f"Loading config from {config_path}")
if os.path.exists(config_path):
# read existing config
config.read(config_path)
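printd here is MemGPT's debug print helper, so the net effect of this hunk is demoting the "Loading config from ..." message from always-on output to debug-only output. A sketch of the idea (the env-var gate below is a stand-in; the real gating mechanism may differ):

```python
# Sketch of a debug-gated print in the spirit of memgpt.utils.printd.
# MEMGPT_DEBUG is a hypothetical stand-in for the real debug switch.
import os

DEBUG = os.getenv("MEMGPT_DEBUG", "") == "1"

def printd(*args, **kwargs) -> None:
    if DEBUG:
        print(*args, **kwargs)

printd("Loading config from ~/.memgpt/config")  # silent unless MEMGPT_DEBUG=1
```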
42 changes: 30 additions & 12 deletions memgpt/llm_api/llm_api_tools.py
@@ -3,17 +3,18 @@
import requests
import os
import time
from typing import List
from typing import List, Optional, Union

from memgpt.credentials import MemGPTCredentials
from memgpt.local_llm.chat_completion_proxy import get_chat_completion
from memgpt.constants import CLI_WARNING_PREFIX
from memgpt.models.chat_completion_response import ChatCompletionResponse
from memgpt.models.chat_completion_request import ChatCompletionRequest, Tool, cast_message_to_subtype
from memgpt.streaming_interface import AgentChunkStreamingInterface, AgentRefreshStreamingInterface

from memgpt.data_types import AgentState, Message

from memgpt.llm_api.openai import openai_chat_completions_request
from memgpt.llm_api.openai import openai_chat_completions_request, openai_chat_completions_process_stream
from memgpt.llm_api.azure_openai import azure_openai_chat_completions_request, MODEL_TO_AZURE_ENGINE
from memgpt.llm_api.google_ai import (
google_ai_chat_completions_request,
@@ -126,14 +127,17 @@ def wrapper(*args, **kwargs):
def create(
agent_state: AgentState,
messages: List[Message],
functions=None,
functions_python=None,
function_call="auto",
functions: list = None,
functions_python: list = None,
function_call: str = "auto",
# hint
first_message=False,
first_message: bool = False,
# use tool naming?
# if false, will use deprecated 'functions' style
use_tool_naming=True,
use_tool_naming: bool = True,
# streaming?
stream: bool = False,
stream_inferface: Optional[Union[AgentRefreshStreamingInterface, AgentChunkStreamingInterface]] = None,
) -> ChatCompletionResponse:
"""Return response to chat completion with backoff"""
from memgpt.utils import printd
@@ -169,11 +173,25 @@ def create(
function_call=function_call,
user=str(agent_state.user_id),
)
return openai_chat_completions_request(
url=agent_state.llm_config.model_endpoint, # https://api.openai.com/v1 -> https://api.openai.com/v1/chat/completions
api_key=credentials.openai_key,
data=data,
)

if stream:
data.stream = True
assert isinstance(stream_inferface, AgentChunkStreamingInterface) or isinstance(
stream_inferface, AgentRefreshStreamingInterface
), type(stream_inferface)
return openai_chat_completions_process_stream(
url=agent_state.llm_config.model_endpoint, # https://api.openai.com/v1 -> https://api.openai.com/v1/chat/completions
api_key=credentials.openai_key,
chat_completion_request=data,
stream_inferface=stream_inferface,
)
else:
data.stream = False
return openai_chat_completions_request(
url=agent_state.llm_config.model_endpoint, # https://api.openai.com/v1 -> https://api.openai.com/v1/chat/completions
api_key=credentials.openai_key,
chat_completion_request=data,
)

# azure
elif agent_state.llm_config.model_endpoint_type == "azure":
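The diff imports openai_chat_completions_process_stream but does not show its body. For orientation, a rough, self-contained sketch of what consuming an OpenAI-compatible stream involves: with stream=True the server answers with server-sent events (data: {json} lines ending in data: [DONE]), and each decoded chunk is handed to a callback. This illustrates the wire protocol only, not the actual MemGPT implementation:

```python
# Rough sketch of consuming an OpenAI-compatible streaming response and
# forwarding each chunk to a callback. Illustrates the SSE protocol, not
# the real openai_chat_completions_process_stream.
import json
from typing import Callable

import requests

def stream_chat_completion(url: str, api_key: str, request_body: dict, on_chunk: Callable[[dict], None]) -> None:
    body = dict(request_body, stream=True)  # server streams when stream=True
    with requests.post(
        f"{url}/chat/completions",
        headers={"Authorization": f"Bearer {api_key}"},
        json=body,
        stream=True,
    ) as resp:
        resp.raise_for_status()
        for line in resp.iter_lines(decode_unicode=True):
            if not line or not line.startswith("data: "):
                continue  # skip blank keep-alive lines between events
            payload = line[len("data: "):]
            if payload == "[DONE]":  # end-of-stream sentinel
                break
            on_chunk(json.loads(payload))  # one ChatCompletionChunk per event
```

Presumably the real implementation pushes each chunk through the stream interface for live display while also accumulating the deltas, since create() still returns a complete ChatCompletionResponse to the caller.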