diff --git a/letta/server/rest_api/admin/tools.py b/letta/server/rest_api/admin/tools.py index 3c451ea012..e857b79ec3 100644 --- a/letta/server/rest_api/admin/tools.py +++ b/letta/server/rest_api/admin/tools.py @@ -26,7 +26,6 @@ class CreateToolResponse(BaseModel): def setup_tools_index_router(server: SyncServer, interface: QueuingInterface): - # get_current_user_with_server = partial(partial(get_current_user, server), password) @router.delete("/tools/{tool_name}", tags=["tools"]) async def delete_tool( diff --git a/letta/server/rest_api/app.py b/letta/server/rest_api/app.py index 1b3fe88a55..7b73674e23 100644 --- a/letta/server/rest_api/app.py +++ b/letta/server/rest_api/app.py @@ -5,8 +5,7 @@ from typing import Optional import uvicorn -from fastapi import FastAPI, Request -from fastapi.responses import JSONResponse +from fastapi import FastAPI from starlette.middleware.cors import CORSMiddleware from letta.server.constants import REST_DEFAULT_PORT @@ -84,21 +83,6 @@ def create_application() -> "FastAPI": allow_headers=["*"], ) - @app.middleware("http") - async def set_current_user_middleware(request: Request, call_next): - user_id = request.headers.get("user_id") - if user_id: - try: - server.set_current_user(user_id) - except ValueError as e: - # Return an HTTP 401 Unauthorized response - # raise HTTPException(status_code=401, detail=str(e)) - return JSONResponse(status_code=401, content={"detail": str(e)}) - else: - server.set_current_user(None) - response = await call_next(request) - return response - for route in v1_routes: app.include_router(route, prefix=API_PREFIX) # this gives undocumented routes for "latest" and bare api calls. diff --git a/letta/server/rest_api/routers/openai/assistants/threads.py b/letta/server/rest_api/routers/openai/assistants/threads.py index 108b14978f..1e9be774ae 100644 --- a/letta/server/rest_api/routers/openai/assistants/threads.py +++ b/letta/server/rest_api/routers/openai/assistants/threads.py @@ -1,7 +1,7 @@ import uuid from typing import TYPE_CHECKING, List -from fastapi import APIRouter, Body, Depends, HTTPException, Path, Query +from fastapi import APIRouter, Body, Depends, Header, HTTPException, Path, Query from letta.constants import DEFAULT_PRESET from letta.schemas.agent import CreateAgent @@ -43,11 +43,12 @@ def create_thread( request: CreateThreadRequest = Body(...), server: SyncServer = Depends(get_letta_server), + user_id: str = Header(None), # Extract user_id from header, default to None if not present ): # TODO: use requests.description and requests.metadata fields # TODO: handle requests.file_ids and requests.tools # TODO: eventually allow request to override embedding/llm model - actor = server.get_current_user() + actor = server.get_user_or_default(user_id=user_id) print("Create thread/agent", request) # create a letta agent @@ -67,8 +68,9 @@ def create_thread( def retrieve_thread( thread_id: str = Path(..., description="The unique identifier of the thread."), server: SyncServer = Depends(get_letta_server), + user_id: str = Header(None), # Extract user_id from header, default to None if not present ): - actor = server.get_current_user() + actor = server.get_user_or_default(user_id=user_id) agent = server.get_agent(user_id=actor.id, agent_id=thread_id) assert agent is not None return OpenAIThread( @@ -100,8 +102,9 @@ def create_message( thread_id: str = Path(..., description="The unique identifier of the thread."), request: CreateMessageRequest = Body(...), server: SyncServer = Depends(get_letta_server), + user_id: str = Header(None), # Extract user_id from header, default to None if not present ): - actor = server.get_current_user() + actor = server.get_user_or_default(user_id=user_id) agent_id = thread_id # create message object message = Message( @@ -143,8 +146,9 @@ def list_messages( after: str = Query(None, description="A cursor for use in pagination. `after` is an object ID that defines your place in the list."), before: str = Query(None, description="A cursor for use in pagination. `after` is an object ID that defines your place in the list."), server: SyncServer = Depends(get_letta_server), + user_id: str = Header(None), # Extract user_id from header, default to None if not present ): - actor = server.get_current_user() + actor = server.get_user_or_default(user_id) after_uuid = after if before else None before_uuid = before if before else None agent_id = thread_id @@ -239,7 +243,6 @@ def create_run( request: CreateRunRequest = Body(...), server: SyncServer = Depends(get_letta_server), ): - server.get_current_user() # TODO: add request.instructions as a message? agent_id = thread_id diff --git a/letta/server/rest_api/routers/openai/chat_completions/chat_completions.py b/letta/server/rest_api/routers/openai/chat_completions/chat_completions.py index 368990a05f..489abe01c6 100644 --- a/letta/server/rest_api/routers/openai/chat_completions/chat_completions.py +++ b/letta/server/rest_api/routers/openai/chat_completions/chat_completions.py @@ -1,7 +1,7 @@ import json from typing import TYPE_CHECKING -from fastapi import APIRouter, Body, Depends, HTTPException +from fastapi import APIRouter, Body, Depends, Header, HTTPException from letta.schemas.enums import MessageRole from letta.schemas.letta_message import FunctionCall, LettaMessage @@ -30,12 +30,14 @@ async def create_chat_completion( completion_request: ChatCompletionRequest = Body(...), server: "SyncServer" = Depends(get_letta_server), + user_id: str = Header(None), # Extract user_id from header, default to None if not present ): """Send a message to a Letta agent via a /chat/completions completion_request The bearer token will be used to identify the user. The 'user' field in the completion_request should be set to the agent ID. """ - actor = server.get_current_user() + actor = server.get_user_or_default(user_id=user_id) + agent_id = completion_request.user if agent_id is None: raise HTTPException(status_code=400, detail="Must pass agent_id in the 'user' field") diff --git a/letta/server/rest_api/routers/v1/agents.py b/letta/server/rest_api/routers/v1/agents.py index cf4a8a6411..09c2c69764 100644 --- a/letta/server/rest_api/routers/v1/agents.py +++ b/letta/server/rest_api/routers/v1/agents.py @@ -2,7 +2,7 @@ from datetime import datetime from typing import Dict, List, Optional, Union -from fastapi import APIRouter, Body, Depends, HTTPException, Query, status +from fastapi import APIRouter, Body, Depends, Header, HTTPException, Query, status from fastapi.responses import JSONResponse, StreamingResponse from starlette.responses import StreamingResponse @@ -40,12 +40,13 @@ @router.get("/", response_model=List[AgentState], operation_id="list_agents") def list_agents( server: "SyncServer" = Depends(get_letta_server), + user_id: str = Header(None), # Extract user_id from header, default to None if not present ): """ List all agents associated with a given user. This endpoint retrieves a list of all agents and their configurations associated with the specified user ID. """ - actor = server.get_current_user() + actor = server.get_user_or_default(user_id=user_id) return server.list_agents(user_id=actor.id) @@ -54,11 +55,12 @@ def list_agents( def create_agent( agent: CreateAgent = Body(...), server: "SyncServer" = Depends(get_letta_server), + user_id: str = Header(None), # Extract user_id from header, default to None if not present ): """ Create a new agent with the specified configuration. """ - actor = server.get_current_user() + actor = server.get_user_or_default(user_id=user_id) agent.user_id = actor.id # TODO: sarah make general # TODO: eventually remove this @@ -74,9 +76,10 @@ def update_agent( agent_id: str, update_agent: UpdateAgentState = Body(...), server: "SyncServer" = Depends(get_letta_server), + user_id: str = Header(None), # Extract user_id from header, default to None if not present ): """Update an exsiting agent""" - actor = server.get_current_user() + actor = server.get_user_or_default(user_id=user_id) update_agent.id = agent_id return server.update_agent(update_agent, user_id=actor.id) @@ -86,11 +89,12 @@ def update_agent( def get_agent_state( agent_id: str, server: "SyncServer" = Depends(get_letta_server), + user_id: str = Header(None), # Extract user_id from header, default to None if not present ): """ Get the state of the agent. """ - actor = server.get_current_user() + actor = server.get_user_or_default(user_id=user_id) if not server.ms.get_agent(user_id=actor.id, agent_id=agent_id): # agent does not exist @@ -103,11 +107,12 @@ def get_agent_state( def delete_agent( agent_id: str, server: "SyncServer" = Depends(get_letta_server), + user_id: str = Header(None), # Extract user_id from header, default to None if not present ): """ Delete an agent. """ - actor = server.get_current_user() + actor = server.get_user_or_default(user_id=user_id) return server.delete_agent(user_id=actor.id, agent_id=agent_id) @@ -120,7 +125,6 @@ def get_agent_sources( """ Get the sources associated with an agent. """ - server.get_current_user() return server.list_attached_sources(agent_id) @@ -155,12 +159,13 @@ def update_agent_memory( agent_id: str, request: Dict = Body(...), server: "SyncServer" = Depends(get_letta_server), + user_id: str = Header(None), # Extract user_id from header, default to None if not present ): """ Update the core memory of a specific agent. This endpoint accepts new memory contents (human and persona) and updates the core memory of the agent identified by the user ID and agent ID. """ - actor = server.get_current_user() + actor = server.get_user_or_default(user_id=user_id) memory = server.update_agent_core_memory(user_id=actor.id, agent_id=agent_id, new_memory_contents=request) return memory @@ -197,11 +202,12 @@ def get_agent_archival_memory( after: Optional[int] = Query(None, description="Unique ID of the memory to start the query range at."), before: Optional[int] = Query(None, description="Unique ID of the memory to end the query range at."), limit: Optional[int] = Query(None, description="How many results to include in the response."), + user_id: str = Header(None), # Extract user_id from header, default to None if not present ): """ Retrieve the memories in an agent's archival memory store (paginated query). """ - actor = server.get_current_user() + actor = server.get_user_or_default(user_id=user_id) # TODO need to add support for non-postgres here # chroma will throw: @@ -221,11 +227,12 @@ def insert_agent_archival_memory( agent_id: str, request: CreateArchivalMemory = Body(...), server: "SyncServer" = Depends(get_letta_server), + user_id: str = Header(None), # Extract user_id from header, default to None if not present ): """ Insert a memory into an agent's archival memory store. """ - actor = server.get_current_user() + actor = server.get_user_or_default(user_id=user_id) return server.insert_archival_memory(user_id=actor.id, agent_id=agent_id, memory_contents=request.text) @@ -238,11 +245,12 @@ def delete_agent_archival_memory( memory_id: str, # memory_id: str = Query(..., description="Unique ID of the memory to be deleted."), server: "SyncServer" = Depends(get_letta_server), + user_id: str = Header(None), # Extract user_id from header, default to None if not present ): """ Delete a memory from an agent's archival memory store. """ - actor = server.get_current_user() + actor = server.get_user_or_default(user_id=user_id) server.delete_archival_memory(user_id=actor.id, agent_id=agent_id, memory_id=memory_id) return JSONResponse(status_code=status.HTTP_200_OK, content={"message": f"Memory id={memory_id} successfully deleted"}) @@ -268,11 +276,12 @@ def get_agent_messages( DEFAULT_MESSAGE_TOOL_KWARG, description="[Only applicable if use_assistant_message is True] The name of the message argument in the designated message tool.", ), + user_id: str = Header(None), # Extract user_id from header, default to None if not present ): """ Retrieve message history for an agent. """ - actor = server.get_current_user() + actor = server.get_user_or_default(user_id=user_id) return server.get_agent_recall_cursor( user_id=actor.id, @@ -306,13 +315,14 @@ async def send_message( agent_id: str, server: SyncServer = Depends(get_letta_server), request: LettaRequest = Body(...), + user_id: str = Header(None), # Extract user_id from header, default to None if not present ): """ Process a user message and return the agent's response. This endpoint accepts a message from a user and processes it through the agent. It can optionally stream the response if 'stream_steps' or 'stream_tokens' is set to True. """ - actor = server.get_current_user() + actor = server.get_user_or_default(user_id=user_id) # TODO(charles): support sending multiple messages assert len(request.messages) == 1, f"Multiple messages not supported: {request.messages}" diff --git a/letta/server/rest_api/routers/v1/blocks.py b/letta/server/rest_api/routers/v1/blocks.py index 1341b8aac1..df9c86cc85 100644 --- a/letta/server/rest_api/routers/v1/blocks.py +++ b/letta/server/rest_api/routers/v1/blocks.py @@ -1,6 +1,6 @@ from typing import TYPE_CHECKING, List, Optional -from fastapi import APIRouter, Body, Depends, HTTPException, Query +from fastapi import APIRouter, Body, Depends, Header, HTTPException, Query from letta.schemas.block import Block, CreateBlock, UpdateBlock from letta.server.rest_api.utils import get_letta_server @@ -19,8 +19,9 @@ def list_blocks( templates_only: bool = Query(True, description="Whether to include only templates"), name: Optional[str] = Query(None, description="Name of the block"), server: SyncServer = Depends(get_letta_server), + user_id: str = Header(None), # Extract user_id from header, default to None if not present ): - actor = server.get_current_user() + actor = server.get_user_or_default(user_id=user_id) blocks = server.get_blocks(user_id=actor.id, label=label, template=templates_only, name=name) if blocks is None: @@ -32,8 +33,9 @@ def list_blocks( def create_block( create_block: CreateBlock = Body(...), server: SyncServer = Depends(get_letta_server), + user_id: str = Header(None), # Extract user_id from header, default to None if not present ): - actor = server.get_current_user() + actor = server.get_user_or_default(user_id=user_id) create_block.user_id = actor.id return server.create_block(user_id=actor.id, request=create_block) diff --git a/letta/server/rest_api/routers/v1/jobs.py b/letta/server/rest_api/routers/v1/jobs.py index 9064052f0d..113ea81d66 100644 --- a/letta/server/rest_api/routers/v1/jobs.py +++ b/letta/server/rest_api/routers/v1/jobs.py @@ -1,6 +1,6 @@ from typing import List, Optional -from fastapi import APIRouter, Depends, Query +from fastapi import APIRouter, Depends, Header, Query from letta.schemas.job import Job from letta.server.rest_api.utils import get_letta_server @@ -13,11 +13,12 @@ def list_jobs( server: "SyncServer" = Depends(get_letta_server), source_id: Optional[str] = Query(None, description="Only list jobs associated with the source."), + user_id: str = Header(None), # Extract user_id from header, default to None if not present ): """ List all jobs. """ - actor = server.get_current_user() + actor = server.get_user_or_default(user_id=user_id) # TODO: add filtering by status jobs = server.list_jobs(user_id=actor.id) @@ -33,11 +34,12 @@ def list_jobs( @router.get("/active", response_model=List[Job], operation_id="list_active_jobs") def list_active_jobs( server: "SyncServer" = Depends(get_letta_server), + user_id: str = Header(None), # Extract user_id from header, default to None if not present ): """ List all active jobs. """ - actor = server.get_current_user() + actor = server.get_user_or_default(user_id=user_id) return server.list_active_jobs(user_id=actor.id) diff --git a/letta/server/rest_api/routers/v1/sources.py b/letta/server/rest_api/routers/v1/sources.py index cb6d9e7d72..2feae2df9b 100644 --- a/letta/server/rest_api/routers/v1/sources.py +++ b/letta/server/rest_api/routers/v1/sources.py @@ -2,7 +2,7 @@ import tempfile from typing import List -from fastapi import APIRouter, BackgroundTasks, Depends, Query, UploadFile +from fastapi import APIRouter, BackgroundTasks, Depends, Header, Query, UploadFile from letta.schemas.document import Document from letta.schemas.job import Job @@ -21,11 +21,12 @@ def get_source( source_id: str, server: "SyncServer" = Depends(get_letta_server), + user_id: str = Header(None), # Extract user_id from header, default to None if not present ): """ Get all sources """ - actor = server.get_current_user() + actor = server.get_user_or_default(user_id=user_id) return server.get_source(source_id=source_id, user_id=actor.id) @@ -34,11 +35,12 @@ def get_source( def get_source_id_by_name( source_name: str, server: "SyncServer" = Depends(get_letta_server), + user_id: str = Header(None), # Extract user_id from header, default to None if not present ): """ Get a source by name """ - actor = server.get_current_user() + actor = server.get_user_or_default(user_id=user_id) source_id = server.get_source_id(source_name=source_name, user_id=actor.id) return source_id @@ -47,11 +49,12 @@ def get_source_id_by_name( @router.get("/", response_model=List[Source], operation_id="list_sources") def list_sources( server: "SyncServer" = Depends(get_letta_server), + user_id: str = Header(None), # Extract user_id from header, default to None if not present ): """ List all data sources created by a user. """ - actor = server.get_current_user() + actor = server.get_user_or_default(user_id=user_id) return server.list_all_sources(user_id=actor.id) @@ -60,11 +63,12 @@ def list_sources( def create_source( source: SourceCreate, server: "SyncServer" = Depends(get_letta_server), + user_id: str = Header(None), # Extract user_id from header, default to None if not present ): """ Create a new data source. """ - actor = server.get_current_user() + actor = server.get_user_or_default(user_id=user_id) return server.create_source(request=source, user_id=actor.id) @@ -74,11 +78,13 @@ def update_source( source_id: str, source: SourceUpdate, server: "SyncServer" = Depends(get_letta_server), + user_id: str = Header(None), # Extract user_id from header, default to None if not present ): """ Update the name or documentation of an existing data source. """ - actor = server.get_current_user() + actor = server.get_user_or_default(user_id=user_id) + assert source.id == source_id, "Source ID in path must match ID in request body" return server.update_source(request=source, user_id=actor.id) @@ -88,11 +94,12 @@ def update_source( def delete_source( source_id: str, server: "SyncServer" = Depends(get_letta_server), + user_id: str = Header(None), # Extract user_id from header, default to None if not present ): """ Delete a data source. """ - actor = server.get_current_user() + actor = server.get_user_or_default(user_id=user_id) server.delete_source(source_id=source_id, user_id=actor.id) @@ -102,11 +109,12 @@ def attach_source_to_agent( source_id: str, agent_id: str = Query(..., description="The unique identifier of the agent to attach the source to."), server: "SyncServer" = Depends(get_letta_server), + user_id: str = Header(None), # Extract user_id from header, default to None if not present ): """ Attach a data source to an existing agent. """ - actor = server.get_current_user() + actor = server.get_user_or_default(user_id=user_id) source = server.ms.get_source(source_id=source_id, user_id=actor.id) assert source is not None, f"Source with id={source_id} not found." @@ -119,11 +127,12 @@ def detach_source_from_agent( source_id: str, agent_id: str = Query(..., description="The unique identifier of the agent to detach the source from."), server: "SyncServer" = Depends(get_letta_server), + user_id: str = Header(None), # Extract user_id from header, default to None if not present ) -> None: """ Detach a data source from an existing agent. """ - actor = server.get_current_user() + actor = server.get_user_or_default(user_id=user_id) return server.detach_source_from_agent(source_id=source_id, agent_id=agent_id, user_id=actor.id) @@ -134,11 +143,12 @@ def upload_file_to_source( source_id: str, background_tasks: BackgroundTasks, server: "SyncServer" = Depends(get_letta_server), + user_id: str = Header(None), # Extract user_id from header, default to None if not present ): """ Upload a file to a data source. """ - actor = server.get_current_user() + actor = server.get_user_or_default(user_id=user_id) source = server.ms.get_source(source_id=source_id, user_id=actor.id) assert source is not None, f"Source with id={source_id} not found." @@ -166,11 +176,12 @@ def upload_file_to_source( def list_passages( source_id: str, server: SyncServer = Depends(get_letta_server), + user_id: str = Header(None), # Extract user_id from header, default to None if not present ): """ List all passages associated with a data source. """ - actor = server.get_current_user() + actor = server.get_user_or_default(user_id=user_id) passages = server.list_data_source_passages(user_id=actor.id, source_id=source_id) return passages @@ -179,11 +190,12 @@ def list_passages( def list_documents( source_id: str, server: "SyncServer" = Depends(get_letta_server), + user_id: str = Header(None), # Extract user_id from header, default to None if not present ): """ List all documents associated with a data source. """ - actor = server.get_current_user() + actor = server.get_user_or_default(user_id=user_id) documents = server.list_data_source_documents(user_id=actor.id, source_id=source_id) return documents diff --git a/letta/server/rest_api/routers/v1/tools.py b/letta/server/rest_api/routers/v1/tools.py index 1a329ddd8a..9b4d58a561 100644 --- a/letta/server/rest_api/routers/v1/tools.py +++ b/letta/server/rest_api/routers/v1/tools.py @@ -1,6 +1,6 @@ from typing import List -from fastapi import APIRouter, Body, Depends, HTTPException +from fastapi import APIRouter, Body, Depends, Header, HTTPException from letta.schemas.tool import Tool, ToolCreate, ToolUpdate from letta.server.rest_api.utils import get_letta_server @@ -13,11 +13,12 @@ def delete_tool( tool_id: str, server: SyncServer = Depends(get_letta_server), + user_id: str = Header(None), # Extract user_id from header, default to None if not present ): """ Delete a tool by name """ - # actor = server.get_current_user() + # actor = server.get_user_or_default(user_id=user_id) server.delete_tool(tool_id=tool_id) @@ -42,11 +43,12 @@ def get_tool( def get_tool_id( tool_name: str, server: SyncServer = Depends(get_letta_server), + user_id: str = Header(None), # Extract user_id from header, default to None if not present ): """ Get a tool ID by name """ - actor = server.get_current_user() + actor = server.get_user_or_default(user_id=user_id) tool_id = server.get_tool_id(tool_name, user_id=actor.id) if tool_id is None: @@ -58,11 +60,12 @@ def get_tool_id( @router.get("/", response_model=List[Tool], operation_id="list_tools") def list_all_tools( server: SyncServer = Depends(get_letta_server), + user_id: str = Header(None), # Extract user_id from header, default to None if not present ): """ Get a list of all tools available to agents created by a user """ - actor = server.get_current_user() + actor = server.get_user_or_default(user_id=user_id) actor.id # TODO: add back when user-specific @@ -75,11 +78,12 @@ def create_tool( tool: ToolCreate = Body(...), update: bool = False, server: SyncServer = Depends(get_letta_server), + user_id: str = Header(None), # Extract user_id from header, default to None if not present ): """ Create a new tool """ - actor = server.get_current_user() + actor = server.get_user_or_default(user_id=user_id) return server.create_tool( request=tool, @@ -94,10 +98,11 @@ def update_tool( tool_id: str, request: ToolUpdate = Body(...), server: SyncServer = Depends(get_letta_server), + user_id: str = Header(None), # Extract user_id from header, default to None if not present ): """ Update an existing tool """ assert tool_id == request.id, "Tool ID in path must match tool ID in request body" - server.get_current_user() + # actor = server.get_user_or_default(user_id=user_id) return server.update_tool(request) diff --git a/letta/server/server.py b/letta/server/server.py index 39547ba670..61b68907e8 100644 --- a/letta/server/server.py +++ b/letta/server/server.py @@ -1089,7 +1089,11 @@ def get_agent(self, user_id: str, agent_id: str, agent_name: Optional[str] = Non def get_user(self, user_id: str) -> User: """Get the user""" - return self.ms.get_user(user_id=user_id) + user = self.ms.get_user(user_id=user_id) + if user is None: + raise ValueError(f"User with user_id {user_id} does not exist") + else: + return user def get_agent_memory(self, agent_id: str) -> Memory: """Return the memory of an agent (core memory)""" @@ -1929,20 +1933,6 @@ def retry_agent_message(self, agent_id: str) -> List[Message]: letta_agent = self._get_or_load_agent(agent_id=agent_id) return letta_agent.retry_message() - def set_current_user(self, user_id: Optional[str]): - """Very hacky way to set the current user for the server, to be replaced once server becomes stateless - - NOTE: clearly not thread-safe, only exists to provide basic user_id support for REST API for now - """ - - # Make sure the user_id actually exists - if user_id is not None: - user_obj = self.get_user(user_id) - if not user_obj: - raise ValueError(f"User with id {user_id} not found") - - self._current_user = user_id - def get_default_user(self) -> User: from letta.constants import ( @@ -1959,8 +1949,9 @@ def get_default_user(self) -> User: self.ms.create_organization(org) # check if default user exists - default_user = self.get_user(DEFAULT_USER_ID) - if not default_user: + try: + self.get_user(DEFAULT_USER_ID) + except ValueError: user = User(name=DEFAULT_USER_NAME, org_id=DEFAULT_ORG_ID, id=DEFAULT_USER_ID) self.ms.create_user(user) @@ -1971,23 +1962,12 @@ def get_default_user(self) -> User: # check if default org exists return self.get_user(DEFAULT_USER_ID) - # TODO(ethan) wire back to real method in future ORM PR - def get_current_user(self) -> User: - """Returns the currently authed user. - - Since server is the core gateway this needs to pass through server as the - first touchpoint. - """ - - # Check if _current_user is set and if it's non-null: - if hasattr(self, "_current_user") and self._current_user is not None: - current_user = self.get_user(self._current_user) - if not current_user: - warnings.warn(f"Provided user '{self._current_user}' not found, using default user") - else: - return current_user - - return self.get_default_user() + def get_user_or_default(self, user_id: Optional[str]) -> User: + """Get the user object for user_id if it exists, otherwise return the default user object""" + if user_id is None: + return self.get_default_user() + else: + return self.get_user(user_id=user_id) def list_models(self) -> List[LLMConfig]: """List available models"""