diff --git a/chat_server/blueprints/personas.py b/chat_server/blueprints/personas.py index 6dd8c3d7..5ca74785 100644 --- a/chat_server/blueprints/personas.py +++ b/chat_server/blueprints/personas.py @@ -48,7 +48,7 @@ PersonaData, ) from chat_server.server_utils.api_dependencies.validators import permitted_access -from chat_server.server_utils.socketio_utils import notify_personas_changed +from chat_server.server_utils.persona_utils import notify_personas_changed, list_personas as _list_personas from utils.database_utils.mongo_utils import MongoFilter, MongoLogicalOperators from utils.database_utils.mongo_utils.queries.wrapper import MongoDocumentsAPI @@ -61,34 +61,10 @@ @router.get("/list") async def list_personas( current_user: CurrentUserData, - request_model: ListPersonasQueryModel = permitted_access(ListPersonasQueryModel), + request_model: ListPersonasQueryModel = permitted_access(ListPersonasQueryModel) ) -> JSONResponse: """Lists personas matching query params""" - filters = [] - if request_model.llms: - filters.append( - MongoFilter( - key="supported_llms", - value=request_model.llms, - logical_operator=MongoLogicalOperators.ALL, - ) - ) - if request_model.user_id and request_model.user_id != "*": - filters.append(MongoFilter(key="user_id", value=request_model.user_id)) - else: - user_filter = [{"user_id": None}, {"user_id": current_user.user_id}] - filters.append( - MongoFilter(value=user_filter, logical_operator=MongoLogicalOperators.OR) - ) - if request_model.only_enabled: - filters.append(MongoFilter(key="enabled", value=True)) - items = MongoDocumentsAPI.PERSONAS.list_items( - filters=filters, result_as_cursor=False - ) - for item in items: - item["id"] = item.pop("_id") - item["enabled"] = item.get("enabled", False) - return JSONResponse(content={"items": items}) + return await _list_personas(current_user, request_model) @router.get("/get/{persona_id}") diff --git a/chat_server/server_utils/socketio_utils.py b/chat_server/server_utils/persona_utils.py similarity index 69% rename from chat_server/server_utils/socketio_utils.py rename to chat_server/server_utils/persona_utils.py index 2c4412ab..44475402 100644 --- a/chat_server/server_utils/socketio_utils.py +++ b/chat_server/server_utils/persona_utils.py @@ -32,10 +32,12 @@ from typing import Optional, List from asyncio import Lock -from chat_server.server_utils.api_dependencies import (CurrentUserModel, - ListPersonasQueryModel) -from chat_server.sio.server import sio +from starlette.responses import JSONResponse +from chat_server.server_utils.api_dependencies import ListPersonasQueryModel, CurrentUserData +from chat_server.sio.server import sio +from utils.database_utils.mongo_utils import MongoFilter, MongoLogicalOperators +from utils.database_utils.mongo_utils.queries.wrapper import MongoDocumentsAPI _LOCK = Lock() @@ -48,10 +50,8 @@ async def notify_personas_changed(supported_llms: Optional[List[str]] = None): :param supported_llms: List of LLM names affected by a transaction. If None, then updates all LLMs listed in database configuration """ - from chat_server.blueprints.personas import list_personas async with _LOCK: - resp = await list_personas(CurrentUserModel(_id="", nickname="", - first_name="", last_name=""), + resp = await list_personas(None, ListPersonasQueryModel(only_enabled=True)) update_time = time() enabled_personas = json.loads(resp.body.decode()) @@ -69,3 +69,32 @@ async def notify_personas_changed(supported_llms: Optional[List[str]] = None): valid_personas[llm].append(persona) sio.emit("configured_personas_changed", {"personas": valid_personas, "update_time": update_time}) + + +async def list_personas(current_user: CurrentUserData, + request_model: ListPersonasQueryModel) -> JSONResponse: + filters = [] + if request_model.llms: + filters.append( + MongoFilter( + key="supported_llms", + value=request_model.llms, + logical_operator=MongoLogicalOperators.ALL, + ) + ) + if request_model.user_id and request_model.user_id != "*": + filters.append(MongoFilter(key="user_id", value=request_model.user_id)) + elif current_user: + user_filter = [{"user_id": None}, {"user_id": current_user.user_id}] + filters.append( + MongoFilter(value=user_filter, logical_operator=MongoLogicalOperators.OR) + ) + if request_model.only_enabled: + filters.append(MongoFilter(key="enabled", value=True)) + items = MongoDocumentsAPI.PERSONAS.list_items( + filters=filters, result_as_cursor=False + ) + for item in items: + item["id"] = item.pop("_id") + item["enabled"] = item.get("enabled", False) + return JSONResponse(content={"items": items}) \ No newline at end of file