Skip to content

Commit

Permalink
feat: add Preset routes to API + patch for tool_call_id max lengt…
Browse files Browse the repository at this point in the history
…h OpenAI error (#1165)
  • Loading branch information
cpacker authored Mar 21, 2024
1 parent 20db9c0 commit 99b6dbc
Show file tree
Hide file tree
Showing 11 changed files with 272 additions and 18 deletions.
46 changes: 41 additions & 5 deletions memgpt/client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from typing import Dict, List, Union, Optional, Tuple

from memgpt.data_types import AgentState, User, Preset, LLMConfig, EmbeddingConfig, Source
from memgpt.models.pydantic_models import HumanModel, PersonaModel
from memgpt.models.pydantic_models import HumanModel, PersonaModel, PresetModel
from memgpt.cli.cli import QuickstartChoice
from memgpt.cli.cli import set_config_with_dict, quickstart as quickstart_func, str_to_quickstart_choice
from memgpt.config import MemGPTConfig
Expand All @@ -30,6 +30,7 @@
from memgpt.server.rest_api.personas.index import ListPersonasResponse
from memgpt.server.rest_api.tools.index import ListToolsResponse, CreateToolResponse
from memgpt.server.rest_api.models.index import ListModelsResponse
from memgpt.server.rest_api.presets.index import CreatePresetResponse, CreatePresetsRequest, ListPresetsResponse


def create_client(base_url: Optional[str] = None, token: Optional[str] = None):
Expand Down Expand Up @@ -85,6 +86,12 @@ def get_agent(self, agent_id: Optional[str] = None, agent_name: Optional[str] =
def create_preset(self, preset: Preset):
raise NotImplementedError

def delete_preset(self, preset_id: uuid.UUID):
raise NotImplementedError

def list_presets(self):
raise NotImplementedError

# memory

def get_agent_memory(self, agent_id: str) -> Dict:
Expand Down Expand Up @@ -300,11 +307,32 @@ def get_agent(self, agent_id: Optional[str] = None, agent_name: Optional[str] =
return self.get_agent_response_to_state(response_obj)

# presets
def create_preset(self, preset: Preset):
raise NotImplementedError
def create_preset(self, preset: Preset) -> CreatePresetResponse:
# TODO should the arg type here be PresetModel, not Preset?
payload = CreatePresetsRequest(
id=str(preset.id),
name=preset.name,
description=preset.description,
system=preset.system,
persona=preset.persona,
human=preset.human,
persona_name=preset.persona_name,
human_name=preset.human_name,
functions_schema=preset.functions_schema,
)
response = requests.post(f"{self.base_url}/api/presets", json=payload.model_dump(), headers=self.headers)
assert response.status_code == 200, f"Failed to create preset: {response.text}"
return CreatePresetResponse(**response.json())

# memory
def delete_preset(self, preset_id: uuid.UUID):
response = requests.delete(f"{self.base_url}/api/presets/{str(preset_id)}", headers=self.headers)
assert response.status_code == 200, f"Failed to delete preset: {response.text}"

def list_presets(self) -> List[PresetModel]:
response = requests.get(f"{self.base_url}/api/presets", headers=self.headers)
return ListPresetsResponse(**response.json()).presets

# memory
def get_agent_memory(self, agent_id: uuid.UUID) -> GetAgentMemoryResponse:
response = requests.get(f"{self.base_url}/api/agents/{agent_id}/memory", headers=self.headers)
return GetAgentMemoryResponse(**response.json())
Expand Down Expand Up @@ -542,10 +570,18 @@ def create_agent(
)
return agent_state

def create_preset(self, preset: Preset):
def create_preset(self, preset: Preset) -> Preset:
if preset.user_id is None:
preset.user_id = self.user_id
preset = self.server.create_preset(preset=preset)
return preset

def delete_preset(self, preset_id: uuid.UUID):
preset = self.server.delete_preset(preset_id=preset_id, user_id=self.user_id)

def list_presets(self) -> List[PresetModel]:
return self.server.list_presets(user_id=self.user_id)

def get_agent_config(self, agent_id: str) -> AgentState:
self.interface.clear()
return self.server.get_agent_config(user_id=self.user_id, agent_id=agent_id)
Expand Down
3 changes: 3 additions & 0 deletions memgpt/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@

MEMGPT_DIR = os.path.join(os.path.expanduser("~"), ".memgpt")

# OpenAI error message: Invalid 'messages[1].tool_calls[0].id': string too long. Expected a string with maximum length 29, but got a string with length 36 instead.
TOOL_CALL_ID_MAX_LEN = 29

# embeddings
MAX_EMBEDDING_DIM = 4096 # maximum supported embeding size - do NOT change or else DBs will need to be reset

Expand Down
19 changes: 15 additions & 4 deletions memgpt/data_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,15 @@
from typing import Optional, List, Dict, TypeVar
import numpy as np

from memgpt.constants import DEFAULT_HUMAN, DEFAULT_MEMGPT_MODEL, DEFAULT_PERSONA, DEFAULT_PRESET, LLM_MAX_TOKENS, MAX_EMBEDDING_DIM
from memgpt.constants import (
DEFAULT_HUMAN,
DEFAULT_MEMGPT_MODEL,
DEFAULT_PERSONA,
DEFAULT_PRESET,
LLM_MAX_TOKENS,
MAX_EMBEDDING_DIM,
TOOL_CALL_ID_MAX_LEN,
)
from memgpt.utils import get_local_time, format_datetime, get_utc_time, create_uuid_from_string
from memgpt.models import chat_completion_response
from memgpt.utils import get_human_text, get_persona_text, printd
Expand Down Expand Up @@ -229,7 +237,7 @@ def dict_to_message(
tool_call_id=openai_message_dict["tool_call_id"] if "tool_call_id" in openai_message_dict else None,
)

def to_openai_dict(self):
def to_openai_dict(self, max_tool_id_length=TOOL_CALL_ID_MAX_LEN):
"""Go from Message class to ChatCompletion message object"""

# TODO change to pydantic casting, eg `return SystemMessageModel(self)`
Expand Down Expand Up @@ -265,13 +273,16 @@ def to_openai_dict(self):
openai_message["name"] = self.name
if self.tool_calls is not None:
openai_message["tool_calls"] = [tool_call.to_dict() for tool_call in self.tool_calls]
if max_tool_id_length:
for tool_call_dict in openai_message["tool_calls"]:
tool_call_dict["id"] = tool_call_dict["id"][:max_tool_id_length]

elif self.role == "tool":
assert all([v is not None for v in [self.role, self.tool_call_id]]), vars(self)
openai_message = {
"content": self.text,
"role": self.role,
"tool_call_id": self.tool_call_id,
"tool_call_id": self.tool_call_id[:max_tool_id_length] if max_tool_id_length else self.tool_call_id,
}
else:
raise ValueError(self.role)
Expand Down Expand Up @@ -540,7 +551,7 @@ def __init__(
class Preset(BaseModel):
name: str = Field(..., description="The name of the preset.")
id: uuid.UUID = Field(default_factory=uuid.uuid4, description="The unique identifier of the preset.")
user_id: uuid.UUID = Field(..., description="The unique identifier of the user who created the preset.")
user_id: Optional[uuid.UUID] = Field(None, description="The unique identifier of the user who created the preset.")
description: Optional[str] = Field(None, description="The description of the preset.")
created_at: datetime = Field(default_factory=datetime.now, description="The unix timestamp of when the preset was created.")
system: str = Field(..., description="The system prompt of the preset.")
Expand Down
4 changes: 3 additions & 1 deletion memgpt/models/pydantic_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,14 @@ class EmbeddingConfigModel(BaseModel):
class PresetModel(BaseModel):
name: str = Field(..., description="The name of the preset.")
id: uuid.UUID = Field(default_factory=uuid.uuid4, description="The unique identifier of the preset.")
user_id: uuid.UUID = Field(..., description="The unique identifier of the user who created the preset.")
user_id: Optional[uuid.UUID] = Field(None, description="The unique identifier of the user who created the preset.")
description: Optional[str] = Field(None, description="The description of the preset.")
created_at: datetime = Field(default_factory=datetime.now, description="The unix timestamp of when the preset was created.")
system: str = Field(..., description="The system prompt of the preset.")
persona: str = Field(default=get_persona_text(DEFAULT_PERSONA), description="The persona of the preset.")
persona_name: Optional[str] = Field(None, description="The name of the persona of the preset.")
human: str = Field(default=get_human_text(DEFAULT_HUMAN), description="The human of the preset.")
human_name: Optional[str] = Field(None, description="The name of the human of the preset.")
functions_schema: List[Dict] = Field(..., description="The functions schema of the preset.")


Expand Down
Empty file.
121 changes: 121 additions & 0 deletions memgpt/server/rest_api/presets/index.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
import uuid
from functools import partial
from typing import List, Optional, Dict, Union

from fastapi import APIRouter, Body, Depends, Query, HTTPException, status
from fastapi.responses import JSONResponse
from pydantic import BaseModel, Field

from memgpt.data_types import Preset # TODO remove
from memgpt.models.pydantic_models import PresetModel
from memgpt.server.rest_api.auth_token import get_current_user
from memgpt.server.rest_api.interface import QueuingInterface
from memgpt.server.server import SyncServer
from memgpt.constants import DEFAULT_HUMAN, DEFAULT_PERSONA
from memgpt.utils import get_human_text, get_persona_text

router = APIRouter()

"""
Implement the following functions:
* List all available presets
* Create a new preset
* Delete a preset
* TODO update a preset
"""


class ListPresetsResponse(BaseModel):
presets: List[PresetModel] = Field(..., description="List of available presets.")


class CreatePresetsRequest(BaseModel):
# TODO is there a cleaner way to create the request from the PresetModel (need to drop fields though)?
name: str = Field(..., description="The name of the preset.")
id: Optional[Union[uuid.UUID, str]] = Field(default_factory=uuid.uuid4, description="The unique identifier of the preset.")
# user_id: uuid.UUID = Field(..., description="The unique identifier of the user who created the preset.")
description: Optional[str] = Field(None, description="The description of the preset.")
# created_at: datetime = Field(default_factory=datetime.now, description="The unix timestamp of when the preset was created.")
system: str = Field(..., description="The system prompt of the preset.")
persona: str = Field(default=get_persona_text(DEFAULT_PERSONA), description="The persona of the preset.")
human: str = Field(default=get_human_text(DEFAULT_HUMAN), description="The human of the preset.")
functions_schema: List[Dict] = Field(..., description="The functions schema of the preset.")
# TODO
persona_name: Optional[str] = Field(None, description="The name of the persona of the preset.")
human_name: Optional[str] = Field(None, description="The name of the human of the preset.")


class CreatePresetResponse(BaseModel):
preset: PresetModel = Field(..., description="The newly created preset.")


def setup_presets_index_router(server: SyncServer, interface: QueuingInterface, password: str):
get_current_user_with_server = partial(partial(get_current_user, server), password)

@router.get("/presets", tags=["presets"], response_model=ListPresetsResponse)
async def list_presets(
user_id: uuid.UUID = Depends(get_current_user_with_server),
):
"""List all presets created by a user."""
# Clear the interface
interface.clear()

try:
presets = server.list_presets(user_id=user_id)
return ListPresetsResponse(presets=presets)
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=500, detail=f"{e}")

@router.post("/presets", tags=["presets"], response_model=CreatePresetResponse)
async def create_preset(
request: CreatePresetsRequest = Body(...),
user_id: uuid.UUID = Depends(get_current_user_with_server),
):
"""Create a preset."""
try:
if isinstance(request.id, str):
request.id = uuid.UUID(request.id)
# new_preset = PresetModel(
new_preset = Preset(
user_id=user_id,
id=request.id,
name=request.name,
description=request.description,
system=request.system,
persona=request.persona,
human=request.human,
functions_schema=request.functions_schema,
persona_name=request.persona_name,
human_name=request.human_name,
)
preset = server.create_preset(preset=new_preset)

# TODO remove once we migrate from Preset to PresetModel
preset = PresetModel(**vars(preset))

return CreatePresetResponse(preset=preset)
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=500, detail=f"{e}")

@router.delete("/presets/{preset_id}", tags=["presets"])
async def delete_preset(
preset_id: uuid.UUID,
user_id: uuid.UUID = Depends(get_current_user_with_server),
):
"""Delete a preset."""
interface.clear()
try:
preset = server.delete_preset(user_id=user_id, preset_id=preset_id)
return JSONResponse(
status_code=status.HTTP_200_OK, content={"message": f"Preset preset_id={str(preset.id)} successfully deleted"}
)
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=500, detail=f"{e}")

return router
2 changes: 2 additions & 0 deletions memgpt/server/rest_api/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from memgpt.server.rest_api.static_files import mount_static_files
from memgpt.server.rest_api.tools.index import setup_tools_index_router
from memgpt.server.rest_api.sources.index import setup_sources_index_router
from memgpt.server.rest_api.presets.index import setup_presets_index_router
from memgpt.server.server import SyncServer
from memgpt.config import MemGPTConfig
from memgpt.server.constants import REST_DEFAULT_PORT
Expand Down Expand Up @@ -102,6 +103,7 @@ def verify_password(credentials: HTTPAuthorizationCredentials = Depends(security
app.include_router(setup_models_index_router(server, interface, password), prefix=API_PREFIX)
app.include_router(setup_tools_index_router(server, interface, password), prefix=API_PREFIX)
app.include_router(setup_sources_index_router(server, interface, password), prefix=API_PREFIX)
app.include_router(setup_presets_index_router(server, interface, password), prefix=API_PREFIX)

# /api/config endpoints
app.include_router(setup_config_index_router(server, interface, password), prefix=API_PREFIX)
Expand Down
15 changes: 10 additions & 5 deletions memgpt/server/rest_api/sources/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,15 +65,20 @@ def setup_sources_index_router(server: SyncServer, interface: QueuingInterface,
get_current_user_with_server = partial(partial(get_current_user, server), password)

@router.get("/sources", tags=["sources"], response_model=ListSourcesResponse)
async def list_source(
async def list_sources(
user_id: uuid.UUID = Depends(get_current_user_with_server),
):
"""List all data sources created by a user."""
# Clear the interface
interface.clear()

sources = server.list_all_sources(user_id=user_id)
return ListSourcesResponse(sources=sources)
try:
sources = server.list_all_sources(user_id=user_id)
return ListSourcesResponse(sources=sources)
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=500, detail=f"{e}")

@router.post("/sources", tags=["sources"], response_model=SourceModel)
async def create_source(
Expand All @@ -100,13 +105,13 @@ async def create_source(

@router.delete("/sources/{source_id}", tags=["sources"])
async def delete_source(
source_id,
source_id: uuid.UUID,
user_id: uuid.UUID = Depends(get_current_user_with_server),
):
"""Delete a data source."""
interface.clear()
try:
server.delete_source(source_id=uuid.UUID(source_id), user_id=user_id)
server.delete_source(source_id=source_id, user_id=user_id)
return JSONResponse(status_code=status.HTTP_200_OK, content={"message": f"Source source_id={source_id} successfully deleted"})
except HTTPException:
raise
Expand Down
25 changes: 23 additions & 2 deletions memgpt/server/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
Preset,
)

from memgpt.models.pydantic_models import SourceModel, PassageModel, DocumentModel
from memgpt.models.pydantic_models import SourceModel, PassageModel, DocumentModel, PresetModel
from memgpt.interface import AgentInterface # abstract

# TODO use custom interface
Expand Down Expand Up @@ -751,13 +751,27 @@ def delete_agent(
if agent is not None:
self.ms.delete_agent(agent_id=agent_id)

def delete_preset(self, user_id: uuid.UUID, preset_id: uuid.UUID) -> Preset:
if self.ms.get_user(user_id=user_id) is None:
raise ValueError(f"User user_id={user_id} does not exist")

# first get the preset by name
preset = self.get_preset(preset_id=preset_id, user_id=user_id)
if preset is None:
raise ValueError(f"Could not find preset_id {preset_id}")
# then delete via name
# TODO allow delete-by-id, eg via server.delete_preset function
self.ms.delete_preset(name=preset.name, user_id=user_id)

return preset

def initialize_default_presets(self, user_id: uuid.UUID):
"""Add default preset options into the metadata store"""
presets.add_default_presets(user_id, self.ms)

def create_preset(self, preset: Preset):
"""Create a new preset using a config"""
if self.ms.get_user(user_id=preset.user_id) is None:
if preset.user_id is not None and self.ms.get_user(user_id=preset.user_id) is None:
raise ValueError(f"User user_id={preset.user_id} does not exist")

self.ms.create_preset(preset)
Expand All @@ -769,6 +783,13 @@ def get_preset(
"""Get the preset"""
return self.ms.get_preset(preset_id=preset_id, name=preset_name, user_id=user_id)

def list_presets(self, user_id: uuid.UUID) -> List[PresetModel]:
# TODO update once we strip Preset in favor of PresetModel
presets = self.ms.list_presets(user_id=user_id)
presets = [PresetModel(**vars(p)) for p in presets]

return presets

def _agent_state_to_config(self, agent_state: AgentState) -> dict:
"""Convert AgentState to a dict for a JSON response"""
assert agent_state is not None
Expand Down
Loading

0 comments on commit 99b6dbc

Please sign in to comment.