Commit
feat: add streaming support for OpenAI-compatible endpoints (#1262)
cpacker authored Apr 18, 2024
1 parent e22f357 commit aeb4a94
Showing 11 changed files with 997 additions and 58 deletions.
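
In short, this commit threads a single stream flag end to end: a new --stream option on memgpt run (Typer's rendering of the stream option added below, e.g. memgpt run --stream) is passed into run_agent_loop, through Agent.step and _get_ai_reply, and into create() in memgpt/llm_api/llm_api_tools.py, which now dispatches to a new openai_chat_completions_process_stream helper when streaming is enabled and falls back to the existing non-streaming request otherwise.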
7 changes: 7 additions & 0 deletions memgpt/agent.py
@@ -403,6 +403,7 @@ def _get_ai_reply(
         message_sequence: List[Message],
         function_call: str = "auto",
         first_message: bool = False,  # hint
+        stream: bool = False,  # TODO move to config?
     ) -> chat_completion_response.ChatCompletionResponse:
         """Get response from LLM API"""
         try:
@@ -414,6 +415,9 @@ def _get_ai_reply(
                 function_call=function_call,
                 # hint
                 first_message=first_message,
+                # streaming
+                stream=stream,
+                stream_inferface=self.interface,
             )
             # special case for 'length'
             if response.choices[0].finish_reason == "length":
@@ -628,6 +632,7 @@ def step(
         skip_verify: bool = False,
         return_dicts: bool = True,  # if True, return dicts, if False, return Message objects
         recreate_message_timestamp: bool = True,  # if True, when input is a Message type, recreate the 'created_at' field
+        stream: bool = False,  # TODO move to config?
     ) -> Tuple[List[Union[dict, Message]], bool, bool, bool]:
         """Top-level event message handler for the MemGPT agent"""

@@ -710,6 +715,7 @@ def validate_json(user_message_text: str, raise_on_error: bool) -> str:
                 response = self._get_ai_reply(
                     message_sequence=input_message_sequence,
                     first_message=True,  # passed through to the prompt formatter
+                    stream=stream,
                 )
                 if verify_first_message_correctness(response, require_monologue=self.first_message_verify_mono):
                     break
@@ -721,6 +727,7 @@ def validate_json(user_message_text: str, raise_on_error: bool) -> str:
             else:
                 response = self._get_ai_reply(
                     message_sequence=input_message_sequence,
+                    stream=stream,
                 )

             # Step 2: check if LLM wanted to call a function
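
For context, a hypothetical caller-side sketch of the new parameter (not part of this commit). Apart from stream, the argument and return shapes follow the step signature above; user_message is an assumed parameter name:

# Hypothetical usage sketch: run one agent step with streaming enabled.
# `agent` is assumed to be an already-constructed memgpt.agent.Agent whose
# interface is streaming-capable (see the CLI change below).
new_messages, heartbeat_request, function_failed, token_warning = agent.step(
    user_message="What did we talk about yesterday?",
    stream=True,  # forwarded to _get_ai_reply and then to create(...)
)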
10 changes: 8 additions & 2 deletions memgpt/cli/cli.py
@@ -13,7 +13,9 @@
 import questionary

 from memgpt.log import logger
-from memgpt.interface import CLIInterface as interface  # for printing to terminal
+
+# from memgpt.interface import CLIInterface as interface  # for printing to terminal
+from memgpt.streaming_interface import StreamingRefreshCLIInterface as interface  # for printing to terminal
 from memgpt.cli.cli_config import configure
 import memgpt.presets.presets as presets
 import memgpt.utils as utils
@@ -445,6 +447,8 @@ def run(
     debug: Annotated[bool, typer.Option(help="Use --debug to enable debugging output")] = False,
     no_verify: Annotated[bool, typer.Option(help="Bypass message verification")] = False,
     yes: Annotated[bool, typer.Option("-y", help="Skip confirmation prompt and use defaults")] = False,
+    # streaming
+    stream: Annotated[bool, typer.Option(help="Enables message streaming in the CLI (if the backend supports it)")] = False,
 ):
     """Start chatting with a MemGPT agent
@@ -710,7 +714,9 @@ def run(
     from memgpt.main import run_agent_loop

     print()  # extra space
-    run_agent_loop(memgpt_agent, config, first, ms, no_verify)  # TODO: add back no_verify
+    run_agent_loop(
+        memgpt_agent=memgpt_agent, config=config, first=first, ms=ms, no_verify=no_verify, stream=stream
+    )  # TODO: add back no_verify


 def delete_agent(
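
The interface classes referenced here live in memgpt/streaming_interface.py, one of the 11 files in this commit but not shown on this page. A rough sketch of the division the call sites imply — the class names come from the imports above, but the method names are illustrative assumptions, not the commit's actual API:

# Illustrative sketch only; the real classes are defined in
# memgpt/streaming_interface.py (not shown in this excerpt).
from abc import ABC, abstractmethod

class AgentChunkStreamingInterface(ABC):
    """Consumes raw chunks as they arrive from a token-streaming backend."""

    @abstractmethod
    def process_chunk(self, chunk) -> None:  # assumed method name
        """Handle one streamed chunk (e.g., print a token delta)."""
        ...

class AgentRefreshStreamingInterface(ABC):
    """Re-renders the full in-progress message on every update."""

    @abstractmethod
    def process_refresh(self, response) -> None:  # assumed method name
        """Redraw output from the accumulated partial response."""
        ...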
3 changes: 2 additions & 1 deletion memgpt/config.py
@@ -90,6 +90,7 @@ def generate_uuid() -> str:
     def load(cls) -> "MemGPTConfig":
         # avoid circular import
         from memgpt.migrate import config_is_compatible, VERSION_CUTOFF
+        from memgpt.utils import printd

         if not config_is_compatible(allow_empty=True):
             error_message = " ".join(
@@ -110,7 +111,7 @@ def load(cls) -> "MemGPTConfig":

         # ensure all configuration directories exist
         cls.create_config_dir()
-        print(f"Loading config from {config_path}")
+        printd(f"Loading config from {config_path}")
         if os.path.exists(config_path):
             # read existing config
             config.read(config_path)
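
The print-to-printd swap silences routine config loads. A tiny sketch of the behavior change, assuming printd is a debug-gated print helper (which is how memgpt.utils.printd is used throughout the codebase):

# Assumption-level illustration of the logging change:
from memgpt.utils import printd

printd("Loading config from ~/.memgpt/config")  # new: silent unless debug output is enabled
print("Loading config from ~/.memgpt/config")   # old: always printed to the terminal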
42 changes: 30 additions & 12 deletions memgpt/llm_api/llm_api_tools.py
@@ -3,17 +3,18 @@
 import requests
 import os
 import time
-from typing import List
+from typing import List, Optional, Union

 from memgpt.credentials import MemGPTCredentials
 from memgpt.local_llm.chat_completion_proxy import get_chat_completion
 from memgpt.constants import CLI_WARNING_PREFIX
 from memgpt.models.chat_completion_response import ChatCompletionResponse
 from memgpt.models.chat_completion_request import ChatCompletionRequest, Tool, cast_message_to_subtype
+from memgpt.streaming_interface import AgentChunkStreamingInterface, AgentRefreshStreamingInterface

 from memgpt.data_types import AgentState, Message

-from memgpt.llm_api.openai import openai_chat_completions_request
+from memgpt.llm_api.openai import openai_chat_completions_request, openai_chat_completions_process_stream
 from memgpt.llm_api.azure_openai import azure_openai_chat_completions_request, MODEL_TO_AZURE_ENGINE
 from memgpt.llm_api.google_ai import (
     google_ai_chat_completions_request,
@@ -126,14 +127,17 @@ def wrapper(*args, **kwargs):
 def create(
     agent_state: AgentState,
     messages: List[Message],
-    functions=None,
-    functions_python=None,
-    function_call="auto",
+    functions: list = None,
+    functions_python: list = None,
+    function_call: str = "auto",
     # hint
-    first_message=False,
+    first_message: bool = False,
     # use tool naming?
     # if false, will use deprecated 'functions' style
-    use_tool_naming=True,
+    use_tool_naming: bool = True,
+    # streaming?
+    stream: bool = False,
+    stream_inferface: Optional[Union[AgentRefreshStreamingInterface, AgentChunkStreamingInterface]] = None,
 ) -> ChatCompletionResponse:
     """Return response to chat completion with backoff"""
     from memgpt.utils import printd
@@ -169,11 +173,25 @@ def create(
             function_call=function_call,
             user=str(agent_state.user_id),
         )
-        return openai_chat_completions_request(
-            url=agent_state.llm_config.model_endpoint,  # https://api.openai.com/v1 -> https://api.openai.com/v1/chat/completions
-            api_key=credentials.openai_key,
-            data=data,
-        )
+
+        if stream:
+            data.stream = True
+            assert isinstance(stream_inferface, AgentChunkStreamingInterface) or isinstance(
+                stream_inferface, AgentRefreshStreamingInterface
+            ), type(stream_inferface)
+            return openai_chat_completions_process_stream(
+                url=agent_state.llm_config.model_endpoint,  # https://api.openai.com/v1 -> https://api.openai.com/v1/chat/completions
+                api_key=credentials.openai_key,
+                chat_completion_request=data,
+                stream_inferface=stream_inferface,
+            )
+        else:
+            data.stream = False
+            return openai_chat_completions_request(
+                url=agent_state.llm_config.model_endpoint,  # https://api.openai.com/v1 -> https://api.openai.com/v1/chat/completions
+                api_key=credentials.openai_key,
+                chat_completion_request=data,
+            )

     # azure
     elif agent_state.llm_config.model_endpoint_type == "azure":
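
openai_chat_completions_process_stream is added in memgpt/llm_api/openai.py (part of this commit, not shown on this page). A minimal sketch of the general technique such a helper implements — consuming server-sent events from /chat/completions with stream=true and forwarding each chunk to the interface passed as stream_inferface. Everything below is an assumption-level illustration, not the commit's actual implementation:

# Assumption-level sketch of SSE consumption for an OpenAI-compatible
# /chat/completions endpoint; the real helper in this commit differs.
import json
import requests

def process_chat_completion_stream(url: str, api_key: str, data: dict, on_chunk) -> list:
    """POST with stream=True and invoke on_chunk for every SSE data event.

    `data` is the request body and must include "stream": True.
    """
    headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"}
    chunks = []
    with requests.post(f"{url}/chat/completions", headers=headers, json=data, stream=True) as resp:
        resp.raise_for_status()
        for line in resp.iter_lines():
            if not line:
                continue  # SSE events are separated by blank lines
            payload = line.decode("utf-8").removeprefix("data: ")
            if payload == "[DONE]":  # OpenAI's end-of-stream sentinel
                break
            chunk = json.loads(payload)
            chunks.append(chunk)
            on_chunk(chunk)  # e.g., hand off to the streaming interface
    return chunks

Note that because create() still returns a ChatCompletionResponse to the agent loop, the real streaming helper must also reassemble the streamed deltas into a complete response object after the stream ends.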
(diff for the remaining 7 changed files not shown)
