Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add GET REST API route for listing tools #1100

Merged
merged 4 commits into from
Mar 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 22 additions & 1 deletion memgpt/metadata.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
""" Metadata store for user/agent/data_source information"""

import os
import inspect as python_inspect
import uuid
import secrets
from typing import Optional, List
Expand All @@ -9,8 +10,9 @@
from memgpt.utils import get_local_time, enforce_types
from memgpt.data_types import AgentState, Source, User, LLMConfig, EmbeddingConfig, Token, Preset
from memgpt.config import MemGPTConfig
from memgpt.functions.functions import load_all_function_sets

from memgpt.models.pydantic_models import PersonaModel, HumanModel
from memgpt.models.pydantic_models import PersonaModel, HumanModel, ToolModel

from sqlalchemy import create_engine, Column, String, BIGINT, select, inspect, text, JSON, BLOB, BINARY, ARRAY, Boolean
from sqlalchemy import func
Expand Down Expand Up @@ -517,6 +519,25 @@ def list_presets(self, user_id: uuid.UUID) -> List[Preset]:
results = session.query(PresetModel).filter(PresetModel.user_id == user_id).all()
return [r.to_record() for r in results]

@enforce_types
def list_tools(self, user_id: uuid.UUID) -> List[ToolModel]:
with self.session_maker() as session:
available_functions = load_all_function_sets()
print(available_functions)
results = [
ToolModel(
name=k,
json_schema=v["json_schema"],
source_type="python",
source_code=python_inspect.getsource(v["python_function"]),
)
for k, v in available_functions.items()
]
print(results)
return results
# results = session.query(PresetModel).filter(PresetModel.user_id == user_id).all()
# return [r.to_record() for r in results]

@enforce_types
def list_agents(self, user_id: uuid.UUID) -> List[AgentState]:
with self.session_maker() as session:
Expand Down
10 changes: 9 additions & 1 deletion memgpt/models/pydantic_models.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import List, Optional, Dict
from typing import List, Optional, Dict, Literal
from pydantic import BaseModel, Field, Json
import uuid
from datetime import datetime
Expand Down Expand Up @@ -36,6 +36,14 @@ class PresetModel(BaseModel):
functions_schema: List[Dict] = Field(..., description="The functions schema of the preset.")


class ToolModel(BaseModel):
# TODO move into database
name: str = Field(..., description="The name of the function.")
json_schema: dict = Field(..., description="The JSON schema of the function.")
source_type: Optional[Literal["python"]] = Field(None, description="The type of the source code.")
source_code: Optional[str] = Field(..., description="The source code of the function.")


class AgentStateModel(BaseModel):
id: uuid.UUID = Field(..., description="The unique identifier of the agent.")
name: str = Field(..., description="The name of the agent.")
Expand Down
2 changes: 2 additions & 0 deletions memgpt/server/rest_api/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from memgpt.server.rest_api.openai_assistants.assistants import setup_openai_assistant_router
from memgpt.server.rest_api.personas.index import setup_personas_index_router
from memgpt.server.rest_api.static_files import mount_static_files
from memgpt.server.rest_api.tools.index import setup_tools_index_router
from memgpt.server.server import SyncServer

"""
Expand Down Expand Up @@ -92,6 +93,7 @@ def verify_password(credentials: HTTPAuthorizationCredentials = Depends(security
app.include_router(setup_humans_index_router(server, interface, password), prefix=API_PREFIX)
app.include_router(setup_personas_index_router(server, interface, password), prefix=API_PREFIX)
app.include_router(setup_models_index_router(server, interface, password), prefix=API_PREFIX)
app.include_router(setup_tools_index_router(server, interface, password), prefix=API_PREFIX)

# /api/config endpoints
app.include_router(setup_config_index_router(server, interface, password), prefix=API_PREFIX)
Expand Down
Empty file.
35 changes: 35 additions & 0 deletions memgpt/server/rest_api/tools/index.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
import uuid
from functools import partial
from typing import List

from fastapi import APIRouter, Depends, Body
from pydantic import BaseModel, Field

from memgpt.models.pydantic_models import ToolModel
from memgpt.server.rest_api.auth_token import get_current_user
from memgpt.server.rest_api.interface import QueuingInterface
from memgpt.server.server import SyncServer

router = APIRouter()


class ListToolsResponse(BaseModel):
tools: List[ToolModel] = Field(..., description="List of tools (functions).")


def setup_tools_index_router(server: SyncServer, interface: QueuingInterface, password: str):
get_current_user_with_server = partial(partial(get_current_user, server), password)

@router.get("/tools", tags=["tools"], response_model=ListToolsResponse)
async def list_tools(
user_id: uuid.UUID = Depends(get_current_user_with_server),
):
"""
Get a list of all tools available to agents created by a user
"""
# Clear the interface
interface.clear()
tools = server.ms.list_tools(user_id=user_id)
return ListToolsResponse(tools=tools)

return router
Loading