From 99b6dbca2e9e61c58728fad90d071909ddd96262 Mon Sep 17 00:00:00 2001 From: Charles Packer Date: Wed, 20 Mar 2024 17:05:06 -0700 Subject: [PATCH] feat: add `Preset` routes to API + patch for `tool_call_id` max length OpenAI error (#1165) --- memgpt/client/client.py | 46 +++++++- memgpt/constants.py | 3 + memgpt/data_types.py | 19 +++- memgpt/models/pydantic_models.py | 4 +- memgpt/server/rest_api/presets/__init__.py | 0 memgpt/server/rest_api/presets/index.py | 121 +++++++++++++++++++++ memgpt/server/rest_api/server.py | 2 + memgpt/server/rest_api/sources/index.py | 15 ++- memgpt/server/server.py | 25 ++++- memgpt/utils.py | 3 +- tests/test_client.py | 52 +++++++++ 11 files changed, 272 insertions(+), 18 deletions(-) create mode 100644 memgpt/server/rest_api/presets/__init__.py create mode 100644 memgpt/server/rest_api/presets/index.py diff --git a/memgpt/client/client.py b/memgpt/client/client.py index 71c430b643..61b9156eb2 100644 --- a/memgpt/client/client.py +++ b/memgpt/client/client.py @@ -5,7 +5,7 @@ from typing import Dict, List, Union, Optional, Tuple from memgpt.data_types import AgentState, User, Preset, LLMConfig, EmbeddingConfig, Source -from memgpt.models.pydantic_models import HumanModel, PersonaModel +from memgpt.models.pydantic_models import HumanModel, PersonaModel, PresetModel from memgpt.cli.cli import QuickstartChoice from memgpt.cli.cli import set_config_with_dict, quickstart as quickstart_func, str_to_quickstart_choice from memgpt.config import MemGPTConfig @@ -30,6 +30,7 @@ from memgpt.server.rest_api.personas.index import ListPersonasResponse from memgpt.server.rest_api.tools.index import ListToolsResponse, CreateToolResponse from memgpt.server.rest_api.models.index import ListModelsResponse +from memgpt.server.rest_api.presets.index import CreatePresetResponse, CreatePresetsRequest, ListPresetsResponse def create_client(base_url: Optional[str] = None, token: Optional[str] = None): @@ -85,6 +86,12 @@ def get_agent(self, agent_id: Optional[str] = None, agent_name: Optional[str] = def create_preset(self, preset: Preset): raise NotImplementedError + def delete_preset(self, preset_id: uuid.UUID): + raise NotImplementedError + + def list_presets(self): + raise NotImplementedError + # memory def get_agent_memory(self, agent_id: str) -> Dict: @@ -300,11 +307,32 @@ def get_agent(self, agent_id: Optional[str] = None, agent_name: Optional[str] = return self.get_agent_response_to_state(response_obj) # presets - def create_preset(self, preset: Preset): - raise NotImplementedError + def create_preset(self, preset: Preset) -> CreatePresetResponse: + # TODO should the arg type here be PresetModel, not Preset? + payload = CreatePresetsRequest( + id=str(preset.id), + name=preset.name, + description=preset.description, + system=preset.system, + persona=preset.persona, + human=preset.human, + persona_name=preset.persona_name, + human_name=preset.human_name, + functions_schema=preset.functions_schema, + ) + response = requests.post(f"{self.base_url}/api/presets", json=payload.model_dump(), headers=self.headers) + assert response.status_code == 200, f"Failed to create preset: {response.text}" + return CreatePresetResponse(**response.json()) - # memory + def delete_preset(self, preset_id: uuid.UUID): + response = requests.delete(f"{self.base_url}/api/presets/{str(preset_id)}", headers=self.headers) + assert response.status_code == 200, f"Failed to delete preset: {response.text}" + + def list_presets(self) -> List[PresetModel]: + response = requests.get(f"{self.base_url}/api/presets", headers=self.headers) + return ListPresetsResponse(**response.json()).presets + # memory def get_agent_memory(self, agent_id: uuid.UUID) -> GetAgentMemoryResponse: response = requests.get(f"{self.base_url}/api/agents/{agent_id}/memory", headers=self.headers) return GetAgentMemoryResponse(**response.json()) @@ -542,10 +570,18 @@ def create_agent( ) return agent_state - def create_preset(self, preset: Preset): + def create_preset(self, preset: Preset) -> Preset: + if preset.user_id is None: + preset.user_id = self.user_id preset = self.server.create_preset(preset=preset) return preset + def delete_preset(self, preset_id: uuid.UUID): + preset = self.server.delete_preset(preset_id=preset_id, user_id=self.user_id) + + def list_presets(self) -> List[PresetModel]: + return self.server.list_presets(user_id=self.user_id) + def get_agent_config(self, agent_id: str) -> AgentState: self.interface.clear() return self.server.get_agent_config(user_id=self.user_id, agent_id=agent_id) diff --git a/memgpt/constants.py b/memgpt/constants.py index 709d3fdbfd..55e82d30c3 100644 --- a/memgpt/constants.py +++ b/memgpt/constants.py @@ -3,6 +3,9 @@ MEMGPT_DIR = os.path.join(os.path.expanduser("~"), ".memgpt") +# OpenAI error message: Invalid 'messages[1].tool_calls[0].id': string too long. Expected a string with maximum length 29, but got a string with length 36 instead. +TOOL_CALL_ID_MAX_LEN = 29 + # embeddings MAX_EMBEDDING_DIM = 4096 # maximum supported embeding size - do NOT change or else DBs will need to be reset diff --git a/memgpt/data_types.py b/memgpt/data_types.py index c3cb219ea4..71c9b4cf2f 100644 --- a/memgpt/data_types.py +++ b/memgpt/data_types.py @@ -5,7 +5,15 @@ from typing import Optional, List, Dict, TypeVar import numpy as np -from memgpt.constants import DEFAULT_HUMAN, DEFAULT_MEMGPT_MODEL, DEFAULT_PERSONA, DEFAULT_PRESET, LLM_MAX_TOKENS, MAX_EMBEDDING_DIM +from memgpt.constants import ( + DEFAULT_HUMAN, + DEFAULT_MEMGPT_MODEL, + DEFAULT_PERSONA, + DEFAULT_PRESET, + LLM_MAX_TOKENS, + MAX_EMBEDDING_DIM, + TOOL_CALL_ID_MAX_LEN, +) from memgpt.utils import get_local_time, format_datetime, get_utc_time, create_uuid_from_string from memgpt.models import chat_completion_response from memgpt.utils import get_human_text, get_persona_text, printd @@ -229,7 +237,7 @@ def dict_to_message( tool_call_id=openai_message_dict["tool_call_id"] if "tool_call_id" in openai_message_dict else None, ) - def to_openai_dict(self): + def to_openai_dict(self, max_tool_id_length=TOOL_CALL_ID_MAX_LEN): """Go from Message class to ChatCompletion message object""" # TODO change to pydantic casting, eg `return SystemMessageModel(self)` @@ -265,13 +273,16 @@ def to_openai_dict(self): openai_message["name"] = self.name if self.tool_calls is not None: openai_message["tool_calls"] = [tool_call.to_dict() for tool_call in self.tool_calls] + if max_tool_id_length: + for tool_call_dict in openai_message["tool_calls"]: + tool_call_dict["id"] = tool_call_dict["id"][:max_tool_id_length] elif self.role == "tool": assert all([v is not None for v in [self.role, self.tool_call_id]]), vars(self) openai_message = { "content": self.text, "role": self.role, - "tool_call_id": self.tool_call_id, + "tool_call_id": self.tool_call_id[:max_tool_id_length] if max_tool_id_length else self.tool_call_id, } else: raise ValueError(self.role) @@ -540,7 +551,7 @@ def __init__( class Preset(BaseModel): name: str = Field(..., description="The name of the preset.") id: uuid.UUID = Field(default_factory=uuid.uuid4, description="The unique identifier of the preset.") - user_id: uuid.UUID = Field(..., description="The unique identifier of the user who created the preset.") + user_id: Optional[uuid.UUID] = Field(None, description="The unique identifier of the user who created the preset.") description: Optional[str] = Field(None, description="The description of the preset.") created_at: datetime = Field(default_factory=datetime.now, description="The unix timestamp of when the preset was created.") system: str = Field(..., description="The system prompt of the preset.") diff --git a/memgpt/models/pydantic_models.py b/memgpt/models/pydantic_models.py index a2c098df83..bb582b7cef 100644 --- a/memgpt/models/pydantic_models.py +++ b/memgpt/models/pydantic_models.py @@ -33,12 +33,14 @@ class EmbeddingConfigModel(BaseModel): class PresetModel(BaseModel): name: str = Field(..., description="The name of the preset.") id: uuid.UUID = Field(default_factory=uuid.uuid4, description="The unique identifier of the preset.") - user_id: uuid.UUID = Field(..., description="The unique identifier of the user who created the preset.") + user_id: Optional[uuid.UUID] = Field(None, description="The unique identifier of the user who created the preset.") description: Optional[str] = Field(None, description="The description of the preset.") created_at: datetime = Field(default_factory=datetime.now, description="The unix timestamp of when the preset was created.") system: str = Field(..., description="The system prompt of the preset.") persona: str = Field(default=get_persona_text(DEFAULT_PERSONA), description="The persona of the preset.") + persona_name: Optional[str] = Field(None, description="The name of the persona of the preset.") human: str = Field(default=get_human_text(DEFAULT_HUMAN), description="The human of the preset.") + human_name: Optional[str] = Field(None, description="The name of the human of the preset.") functions_schema: List[Dict] = Field(..., description="The functions schema of the preset.") diff --git a/memgpt/server/rest_api/presets/__init__.py b/memgpt/server/rest_api/presets/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/memgpt/server/rest_api/presets/index.py b/memgpt/server/rest_api/presets/index.py new file mode 100644 index 0000000000..b5c0055493 --- /dev/null +++ b/memgpt/server/rest_api/presets/index.py @@ -0,0 +1,121 @@ +import uuid +from functools import partial +from typing import List, Optional, Dict, Union + +from fastapi import APIRouter, Body, Depends, Query, HTTPException, status +from fastapi.responses import JSONResponse +from pydantic import BaseModel, Field + +from memgpt.data_types import Preset # TODO remove +from memgpt.models.pydantic_models import PresetModel +from memgpt.server.rest_api.auth_token import get_current_user +from memgpt.server.rest_api.interface import QueuingInterface +from memgpt.server.server import SyncServer +from memgpt.constants import DEFAULT_HUMAN, DEFAULT_PERSONA +from memgpt.utils import get_human_text, get_persona_text + +router = APIRouter() + +""" +Implement the following functions: +* List all available presets +* Create a new preset +* Delete a preset +* TODO update a preset +""" + + +class ListPresetsResponse(BaseModel): + presets: List[PresetModel] = Field(..., description="List of available presets.") + + +class CreatePresetsRequest(BaseModel): + # TODO is there a cleaner way to create the request from the PresetModel (need to drop fields though)? + name: str = Field(..., description="The name of the preset.") + id: Optional[Union[uuid.UUID, str]] = Field(default_factory=uuid.uuid4, description="The unique identifier of the preset.") + # user_id: uuid.UUID = Field(..., description="The unique identifier of the user who created the preset.") + description: Optional[str] = Field(None, description="The description of the preset.") + # created_at: datetime = Field(default_factory=datetime.now, description="The unix timestamp of when the preset was created.") + system: str = Field(..., description="The system prompt of the preset.") + persona: str = Field(default=get_persona_text(DEFAULT_PERSONA), description="The persona of the preset.") + human: str = Field(default=get_human_text(DEFAULT_HUMAN), description="The human of the preset.") + functions_schema: List[Dict] = Field(..., description="The functions schema of the preset.") + # TODO + persona_name: Optional[str] = Field(None, description="The name of the persona of the preset.") + human_name: Optional[str] = Field(None, description="The name of the human of the preset.") + + +class CreatePresetResponse(BaseModel): + preset: PresetModel = Field(..., description="The newly created preset.") + + +def setup_presets_index_router(server: SyncServer, interface: QueuingInterface, password: str): + get_current_user_with_server = partial(partial(get_current_user, server), password) + + @router.get("/presets", tags=["presets"], response_model=ListPresetsResponse) + async def list_presets( + user_id: uuid.UUID = Depends(get_current_user_with_server), + ): + """List all presets created by a user.""" + # Clear the interface + interface.clear() + + try: + presets = server.list_presets(user_id=user_id) + return ListPresetsResponse(presets=presets) + except HTTPException: + raise + except Exception as e: + raise HTTPException(status_code=500, detail=f"{e}") + + @router.post("/presets", tags=["presets"], response_model=CreatePresetResponse) + async def create_preset( + request: CreatePresetsRequest = Body(...), + user_id: uuid.UUID = Depends(get_current_user_with_server), + ): + """Create a preset.""" + try: + if isinstance(request.id, str): + request.id = uuid.UUID(request.id) + # new_preset = PresetModel( + new_preset = Preset( + user_id=user_id, + id=request.id, + name=request.name, + description=request.description, + system=request.system, + persona=request.persona, + human=request.human, + functions_schema=request.functions_schema, + persona_name=request.persona_name, + human_name=request.human_name, + ) + preset = server.create_preset(preset=new_preset) + + # TODO remove once we migrate from Preset to PresetModel + preset = PresetModel(**vars(preset)) + + return CreatePresetResponse(preset=preset) + except HTTPException: + raise + except Exception as e: + raise HTTPException(status_code=500, detail=f"{e}") + + @router.delete("/presets/{preset_id}", tags=["presets"]) + async def delete_preset( + preset_id: uuid.UUID, + user_id: uuid.UUID = Depends(get_current_user_with_server), + ): + """Delete a preset.""" + interface.clear() + try: + preset = server.delete_preset(user_id=user_id, preset_id=preset_id) + return JSONResponse( + status_code=status.HTTP_200_OK, content={"message": f"Preset preset_id={str(preset.id)} successfully deleted"} + ) + except HTTPException: + raise + except Exception as e: + raise HTTPException(status_code=500, detail=f"{e}") + + return router diff --git a/memgpt/server/rest_api/server.py b/memgpt/server/rest_api/server.py index 88d80e79af..6149c41bb6 100644 --- a/memgpt/server/rest_api/server.py +++ b/memgpt/server/rest_api/server.py @@ -25,6 +25,7 @@ from memgpt.server.rest_api.static_files import mount_static_files from memgpt.server.rest_api.tools.index import setup_tools_index_router from memgpt.server.rest_api.sources.index import setup_sources_index_router +from memgpt.server.rest_api.presets.index import setup_presets_index_router from memgpt.server.server import SyncServer from memgpt.config import MemGPTConfig from memgpt.server.constants import REST_DEFAULT_PORT @@ -102,6 +103,7 @@ def verify_password(credentials: HTTPAuthorizationCredentials = Depends(security app.include_router(setup_models_index_router(server, interface, password), prefix=API_PREFIX) app.include_router(setup_tools_index_router(server, interface, password), prefix=API_PREFIX) app.include_router(setup_sources_index_router(server, interface, password), prefix=API_PREFIX) +app.include_router(setup_presets_index_router(server, interface, password), prefix=API_PREFIX) # /api/config endpoints app.include_router(setup_config_index_router(server, interface, password), prefix=API_PREFIX) diff --git a/memgpt/server/rest_api/sources/index.py b/memgpt/server/rest_api/sources/index.py index d8ff69115d..d7e8f69ebe 100644 --- a/memgpt/server/rest_api/sources/index.py +++ b/memgpt/server/rest_api/sources/index.py @@ -65,15 +65,20 @@ def setup_sources_index_router(server: SyncServer, interface: QueuingInterface, get_current_user_with_server = partial(partial(get_current_user, server), password) @router.get("/sources", tags=["sources"], response_model=ListSourcesResponse) - async def list_source( + async def list_sources( user_id: uuid.UUID = Depends(get_current_user_with_server), ): """List all data sources created by a user.""" # Clear the interface interface.clear() - sources = server.list_all_sources(user_id=user_id) - return ListSourcesResponse(sources=sources) + try: + sources = server.list_all_sources(user_id=user_id) + return ListSourcesResponse(sources=sources) + except HTTPException: + raise + except Exception as e: + raise HTTPException(status_code=500, detail=f"{e}") @router.post("/sources", tags=["sources"], response_model=SourceModel) async def create_source( @@ -100,13 +105,13 @@ async def create_source( @router.delete("/sources/{source_id}", tags=["sources"]) async def delete_source( - source_id, + source_id: uuid.UUID, user_id: uuid.UUID = Depends(get_current_user_with_server), ): """Delete a data source.""" interface.clear() try: - server.delete_source(source_id=uuid.UUID(source_id), user_id=user_id) + server.delete_source(source_id=source_id, user_id=user_id) return JSONResponse(status_code=status.HTTP_200_OK, content={"message": f"Source source_id={source_id} successfully deleted"}) except HTTPException: raise diff --git a/memgpt/server/server.py b/memgpt/server/server.py index 148718a6c3..e424c9a1ff 100644 --- a/memgpt/server/server.py +++ b/memgpt/server/server.py @@ -37,7 +37,7 @@ Preset, ) -from memgpt.models.pydantic_models import SourceModel, PassageModel, DocumentModel +from memgpt.models.pydantic_models import SourceModel, PassageModel, DocumentModel, PresetModel from memgpt.interface import AgentInterface # abstract # TODO use custom interface @@ -751,13 +751,27 @@ def delete_agent( if agent is not None: self.ms.delete_agent(agent_id=agent_id) + def delete_preset(self, user_id: uuid.UUID, preset_id: uuid.UUID) -> Preset: + if self.ms.get_user(user_id=user_id) is None: + raise ValueError(f"User user_id={user_id} does not exist") + + # first get the preset by name + preset = self.get_preset(preset_id=preset_id, user_id=user_id) + if preset is None: + raise ValueError(f"Could not find preset_id {preset_id}") + # then delete via name + # TODO allow delete-by-id, eg via server.delete_preset function + self.ms.delete_preset(name=preset.name, user_id=user_id) + + return preset + def initialize_default_presets(self, user_id: uuid.UUID): """Add default preset options into the metadata store""" presets.add_default_presets(user_id, self.ms) def create_preset(self, preset: Preset): """Create a new preset using a config""" - if self.ms.get_user(user_id=preset.user_id) is None: + if preset.user_id is not None and self.ms.get_user(user_id=preset.user_id) is None: raise ValueError(f"User user_id={preset.user_id} does not exist") self.ms.create_preset(preset) @@ -769,6 +783,13 @@ def get_preset( """Get the preset""" return self.ms.get_preset(preset_id=preset_id, name=preset_name, user_id=user_id) + def list_presets(self, user_id: uuid.UUID) -> List[PresetModel]: + # TODO update once we strip Preset in favor of PresetModel + presets = self.ms.list_presets(user_id=user_id) + presets = [PresetModel(**vars(p)) for p in presets] + + return presets + def _agent_state_to_config(self, agent_state: AgentState) -> dict: """Convert AgentState to a dict for a JSON response""" assert agent_state is not None diff --git a/memgpt/utils.py b/memgpt/utils.py index 3527e17e0c..32451592dc 100644 --- a/memgpt/utils.py +++ b/memgpt/utils.py @@ -33,6 +33,7 @@ CORE_MEMORY_HUMAN_CHAR_LIMIT, CORE_MEMORY_PERSONA_CHAR_LIMIT, JSON_ENSURE_ASCII, + TOOL_CALL_ID_MAX_LEN, ) from memgpt.models.chat_completion_response import ChatCompletionResponse @@ -469,7 +470,7 @@ def get_tool_call_id() -> str: - return str(uuid.uuid4()) + return str(uuid.uuid4())[:TOOL_CALL_ID_MAX_LEN] def assistant_function_to_tool(assistant_message: dict) -> dict: diff --git a/tests/test_client.py b/tests/test_client.py index 6c25091e0f..7e67759d86 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -7,6 +7,7 @@ from memgpt.server.rest_api.server import start_server from memgpt import Admin, create_client from memgpt.constants import DEFAULT_PRESET +from memgpt.data_types import Preset # TODO move to PresetModel from dotenv import load_dotenv from tests.config import TestMGPTConfig @@ -287,3 +288,54 @@ def test_sources(client, agent): # delete the source client.delete_source(source.id) + + +def test_presets(client, agent): + + new_preset = Preset( + # user_id=client.user_id, + name="pytest_test_preset", + description="DUMMY_DESCRIPTION", + system="DUMMY_SYSTEM", + persona="DUMMY_PERSONA", + persona_name="DUMMY_PERSONA_NAME", + human="DUMMY_HUMAN", + human_name="DUMMY_HUMAN_NAME", + functions_schema=[ + { + "name": "send_message", + "json_schema": { + "name": "send_message", + "description": "Sends a message to the human user.", + "parameters": { + "type": "object", + "properties": { + "message": {"type": "string", "description": "Message contents. All unicode (including emojis) are supported."} + }, + "required": ["message"], + }, + }, + "tags": ["memgpt-base"], + "source_type": "python", + "source_code": 'def send_message(self, message: str) -> Optional[str]:\n """\n Sends a message to the human user.\n\n Args:\n message (str): Message contents. All unicode (including emojis) are supported.\n\n Returns:\n Optional[str]: None is always returned as this function does not produce a response.\n """\n self.interface.assistant_message(message)\n return None\n', + } + ], + ) + + # List all presets and make sure the preset is NOT in the list + all_presets = client.list_presets() + assert new_preset.id not in [p.id for p in all_presets], (new_preset, all_presets) + + # Create a preset + client.create_preset(preset=new_preset) + + # List all presets and make sure the preset is in the list + all_presets = client.list_presets() + assert new_preset.id in [p.id for p in all_presets], (new_preset, all_presets) + + # Delete the preset + client.delete_preset(preset_id=new_preset.id) + + # List all presets and make sure the preset is NOT in the list + all_presets = client.list_presets() + assert new_preset.id not in [p.id for p in all_presets], (new_preset, all_presets)