Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: remove get_current_user and replace with direct header read #1834

Merged
merged 3 commits into from
Oct 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion letta/server/rest_api/admin/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ class CreateToolResponse(BaseModel):


def setup_tools_index_router(server: SyncServer, interface: QueuingInterface):
# get_current_user_with_server = partial(partial(get_current_user, server), password)

@router.delete("/tools/{tool_name}", tags=["tools"])
async def delete_tool(
Expand Down
18 changes: 1 addition & 17 deletions letta/server/rest_api/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,7 @@
from typing import Optional

import uvicorn
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse
from fastapi import FastAPI
from starlette.middleware.cors import CORSMiddleware

from letta.server.constants import REST_DEFAULT_PORT
Expand Down Expand Up @@ -84,21 +83,6 @@ def create_application() -> "FastAPI":
allow_headers=["*"],
)

@app.middleware("http")
async def set_current_user_middleware(request: Request, call_next):
user_id = request.headers.get("user_id")
if user_id:
try:
server.set_current_user(user_id)
except ValueError as e:
# Return an HTTP 401 Unauthorized response
# raise HTTPException(status_code=401, detail=str(e))
return JSONResponse(status_code=401, content={"detail": str(e)})
else:
server.set_current_user(None)
response = await call_next(request)
return response

for route in v1_routes:
app.include_router(route, prefix=API_PREFIX)
# this gives undocumented routes for "latest" and bare api calls.
Expand Down
15 changes: 9 additions & 6 deletions letta/server/rest_api/routers/openai/assistants/threads.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import uuid
from typing import TYPE_CHECKING, List

from fastapi import APIRouter, Body, Depends, HTTPException, Path, Query
from fastapi import APIRouter, Body, Depends, Header, HTTPException, Path, Query

from letta.constants import DEFAULT_PRESET
from letta.schemas.agent import CreateAgent
Expand Down Expand Up @@ -43,11 +43,12 @@
def create_thread(
request: CreateThreadRequest = Body(...),
server: SyncServer = Depends(get_letta_server),
user_id: str = Header(None), # Extract user_id from header, default to None if not present
):
# TODO: use requests.description and requests.metadata fields
# TODO: handle requests.file_ids and requests.tools
# TODO: eventually allow request to override embedding/llm model
actor = server.get_current_user()
actor = server.get_user_or_default(user_id=user_id)

print("Create thread/agent", request)
# create a letta agent
Expand All @@ -67,8 +68,9 @@ def create_thread(
def retrieve_thread(
thread_id: str = Path(..., description="The unique identifier of the thread."),
server: SyncServer = Depends(get_letta_server),
user_id: str = Header(None), # Extract user_id from header, default to None if not present
):
actor = server.get_current_user()
actor = server.get_user_or_default(user_id=user_id)
agent = server.get_agent(user_id=actor.id, agent_id=thread_id)
assert agent is not None
return OpenAIThread(
Expand Down Expand Up @@ -100,8 +102,9 @@ def create_message(
thread_id: str = Path(..., description="The unique identifier of the thread."),
request: CreateMessageRequest = Body(...),
server: SyncServer = Depends(get_letta_server),
user_id: str = Header(None), # Extract user_id from header, default to None if not present
):
actor = server.get_current_user()
actor = server.get_user_or_default(user_id=user_id)
agent_id = thread_id
# create message object
message = Message(
Expand Down Expand Up @@ -143,8 +146,9 @@ def list_messages(
after: str = Query(None, description="A cursor for use in pagination. `after` is an object ID that defines your place in the list."),
before: str = Query(None, description="A cursor for use in pagination. `after` is an object ID that defines your place in the list."),
server: SyncServer = Depends(get_letta_server),
user_id: str = Header(None), # Extract user_id from header, default to None if not present
):
actor = server.get_current_user()
actor = server.get_user_or_default(user_id)
after_uuid = after if before else None
before_uuid = before if before else None
agent_id = thread_id
Expand Down Expand Up @@ -239,7 +243,6 @@ def create_run(
request: CreateRunRequest = Body(...),
server: SyncServer = Depends(get_letta_server),
):
server.get_current_user()

# TODO: add request.instructions as a message?
agent_id = thread_id
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import json
from typing import TYPE_CHECKING

from fastapi import APIRouter, Body, Depends, HTTPException
from fastapi import APIRouter, Body, Depends, Header, HTTPException

from letta.schemas.enums import MessageRole
from letta.schemas.letta_message import FunctionCall, LettaMessage
Expand Down Expand Up @@ -30,12 +30,14 @@
async def create_chat_completion(
completion_request: ChatCompletionRequest = Body(...),
server: "SyncServer" = Depends(get_letta_server),
user_id: str = Header(None), # Extract user_id from header, default to None if not present
):
"""Send a message to a Letta agent via a /chat/completions completion_request
The bearer token will be used to identify the user.
The 'user' field in the completion_request should be set to the agent ID.
"""
actor = server.get_current_user()
actor = server.get_user_or_default(user_id=user_id)

agent_id = completion_request.user
if agent_id is None:
raise HTTPException(status_code=400, detail="Must pass agent_id in the 'user' field")
Expand Down
36 changes: 23 additions & 13 deletions letta/server/rest_api/routers/v1/agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from datetime import datetime
from typing import Dict, List, Optional, Union

from fastapi import APIRouter, Body, Depends, HTTPException, Query, status
from fastapi import APIRouter, Body, Depends, Header, HTTPException, Query, status
from fastapi.responses import JSONResponse, StreamingResponse
from starlette.responses import StreamingResponse

Expand Down Expand Up @@ -40,12 +40,13 @@
@router.get("/", response_model=List[AgentState], operation_id="list_agents")
def list_agents(
server: "SyncServer" = Depends(get_letta_server),
user_id: str = Header(None), # Extract user_id from header, default to None if not present
):
"""
List all agents associated with a given user.
This endpoint retrieves a list of all agents and their configurations associated with the specified user ID.
"""
actor = server.get_current_user()
actor = server.get_user_or_default(user_id=user_id)

return server.list_agents(user_id=actor.id)

Expand All @@ -54,11 +55,12 @@ def list_agents(
def create_agent(
agent: CreateAgent = Body(...),
server: "SyncServer" = Depends(get_letta_server),
user_id: str = Header(None), # Extract user_id from header, default to None if not present
):
"""
Create a new agent with the specified configuration.
"""
actor = server.get_current_user()
actor = server.get_user_or_default(user_id=user_id)
agent.user_id = actor.id
# TODO: sarah make general
# TODO: eventually remove this
Expand All @@ -74,9 +76,10 @@ def update_agent(
agent_id: str,
update_agent: UpdateAgentState = Body(...),
server: "SyncServer" = Depends(get_letta_server),
user_id: str = Header(None), # Extract user_id from header, default to None if not present
):
"""Update an exsiting agent"""
actor = server.get_current_user()
actor = server.get_user_or_default(user_id=user_id)

update_agent.id = agent_id
return server.update_agent(update_agent, user_id=actor.id)
Expand All @@ -86,11 +89,12 @@ def update_agent(
def get_agent_state(
agent_id: str,
server: "SyncServer" = Depends(get_letta_server),
user_id: str = Header(None), # Extract user_id from header, default to None if not present
):
"""
Get the state of the agent.
"""
actor = server.get_current_user()
actor = server.get_user_or_default(user_id=user_id)

if not server.ms.get_agent(user_id=actor.id, agent_id=agent_id):
# agent does not exist
Expand All @@ -103,11 +107,12 @@ def get_agent_state(
def delete_agent(
agent_id: str,
server: "SyncServer" = Depends(get_letta_server),
user_id: str = Header(None), # Extract user_id from header, default to None if not present
):
"""
Delete an agent.
"""
actor = server.get_current_user()
actor = server.get_user_or_default(user_id=user_id)

return server.delete_agent(user_id=actor.id, agent_id=agent_id)

Expand All @@ -120,7 +125,6 @@ def get_agent_sources(
"""
Get the sources associated with an agent.
"""
server.get_current_user()

return server.list_attached_sources(agent_id)

Expand Down Expand Up @@ -155,12 +159,13 @@ def update_agent_memory(
agent_id: str,
request: Dict = Body(...),
server: "SyncServer" = Depends(get_letta_server),
user_id: str = Header(None), # Extract user_id from header, default to None if not present
):
"""
Update the core memory of a specific agent.
This endpoint accepts new memory contents (human and persona) and updates the core memory of the agent identified by the user ID and agent ID.
"""
actor = server.get_current_user()
actor = server.get_user_or_default(user_id=user_id)

memory = server.update_agent_core_memory(user_id=actor.id, agent_id=agent_id, new_memory_contents=request)
return memory
Expand Down Expand Up @@ -197,11 +202,12 @@ def get_agent_archival_memory(
after: Optional[int] = Query(None, description="Unique ID of the memory to start the query range at."),
before: Optional[int] = Query(None, description="Unique ID of the memory to end the query range at."),
limit: Optional[int] = Query(None, description="How many results to include in the response."),
user_id: str = Header(None), # Extract user_id from header, default to None if not present
):
"""
Retrieve the memories in an agent's archival memory store (paginated query).
"""
actor = server.get_current_user()
actor = server.get_user_or_default(user_id=user_id)

# TODO need to add support for non-postgres here
# chroma will throw:
Expand All @@ -221,11 +227,12 @@ def insert_agent_archival_memory(
agent_id: str,
request: CreateArchivalMemory = Body(...),
server: "SyncServer" = Depends(get_letta_server),
user_id: str = Header(None), # Extract user_id from header, default to None if not present
):
"""
Insert a memory into an agent's archival memory store.
"""
actor = server.get_current_user()
actor = server.get_user_or_default(user_id=user_id)

return server.insert_archival_memory(user_id=actor.id, agent_id=agent_id, memory_contents=request.text)

Expand All @@ -238,11 +245,12 @@ def delete_agent_archival_memory(
memory_id: str,
# memory_id: str = Query(..., description="Unique ID of the memory to be deleted."),
server: "SyncServer" = Depends(get_letta_server),
user_id: str = Header(None), # Extract user_id from header, default to None if not present
):
"""
Delete a memory from an agent's archival memory store.
"""
actor = server.get_current_user()
actor = server.get_user_or_default(user_id=user_id)

server.delete_archival_memory(user_id=actor.id, agent_id=agent_id, memory_id=memory_id)
return JSONResponse(status_code=status.HTTP_200_OK, content={"message": f"Memory id={memory_id} successfully deleted"})
Expand All @@ -268,11 +276,12 @@ def get_agent_messages(
DEFAULT_MESSAGE_TOOL_KWARG,
description="[Only applicable if use_assistant_message is True] The name of the message argument in the designated message tool.",
),
user_id: str = Header(None), # Extract user_id from header, default to None if not present
):
"""
Retrieve message history for an agent.
"""
actor = server.get_current_user()
actor = server.get_user_or_default(user_id=user_id)

return server.get_agent_recall_cursor(
user_id=actor.id,
Expand Down Expand Up @@ -306,13 +315,14 @@ async def send_message(
agent_id: str,
server: SyncServer = Depends(get_letta_server),
request: LettaRequest = Body(...),
user_id: str = Header(None), # Extract user_id from header, default to None if not present
):
"""
Process a user message and return the agent's response.
This endpoint accepts a message from a user and processes it through the agent.
It can optionally stream the response if 'stream_steps' or 'stream_tokens' is set to True.
"""
actor = server.get_current_user()
actor = server.get_user_or_default(user_id=user_id)

# TODO(charles): support sending multiple messages
assert len(request.messages) == 1, f"Multiple messages not supported: {request.messages}"
Expand Down
8 changes: 5 additions & 3 deletions letta/server/rest_api/routers/v1/blocks.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import TYPE_CHECKING, List, Optional

from fastapi import APIRouter, Body, Depends, HTTPException, Query
from fastapi import APIRouter, Body, Depends, Header, HTTPException, Query

from letta.schemas.block import Block, CreateBlock, UpdateBlock
from letta.server.rest_api.utils import get_letta_server
Expand All @@ -19,8 +19,9 @@ def list_blocks(
templates_only: bool = Query(True, description="Whether to include only templates"),
name: Optional[str] = Query(None, description="Name of the block"),
server: SyncServer = Depends(get_letta_server),
user_id: str = Header(None), # Extract user_id from header, default to None if not present
):
actor = server.get_current_user()
actor = server.get_user_or_default(user_id=user_id)

blocks = server.get_blocks(user_id=actor.id, label=label, template=templates_only, name=name)
if blocks is None:
Expand All @@ -32,8 +33,9 @@ def list_blocks(
def create_block(
create_block: CreateBlock = Body(...),
server: SyncServer = Depends(get_letta_server),
user_id: str = Header(None), # Extract user_id from header, default to None if not present
):
actor = server.get_current_user()
actor = server.get_user_or_default(user_id=user_id)

create_block.user_id = actor.id
return server.create_block(user_id=actor.id, request=create_block)
Expand Down
8 changes: 5 additions & 3 deletions letta/server/rest_api/routers/v1/jobs.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import List, Optional

from fastapi import APIRouter, Depends, Query
from fastapi import APIRouter, Depends, Header, Query

from letta.schemas.job import Job
from letta.server.rest_api.utils import get_letta_server
Expand All @@ -13,11 +13,12 @@
def list_jobs(
server: "SyncServer" = Depends(get_letta_server),
source_id: Optional[str] = Query(None, description="Only list jobs associated with the source."),
user_id: str = Header(None), # Extract user_id from header, default to None if not present
):
"""
List all jobs.
"""
actor = server.get_current_user()
actor = server.get_user_or_default(user_id=user_id)

# TODO: add filtering by status
jobs = server.list_jobs(user_id=actor.id)
Expand All @@ -33,11 +34,12 @@ def list_jobs(
@router.get("/active", response_model=List[Job], operation_id="list_active_jobs")
def list_active_jobs(
server: "SyncServer" = Depends(get_letta_server),
user_id: str = Header(None), # Extract user_id from header, default to None if not present
):
"""
List all active jobs.
"""
actor = server.get_current_user()
actor = server.get_user_or_default(user_id=user_id)

return server.list_active_jobs(user_id=actor.id)

Expand Down
Loading
Loading