Skip to content

Commit

Permalink
Refactor socketio_utils to persona_utils
Browse files Browse the repository at this point in the history
Move `list_personas` logic to a util method
Support no current user in `list_personas` to retrieve all personas with no user filtering
  • Loading branch information
NeonDaniel committed Jan 6, 2025
1 parent 210e4a8 commit 590812a
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 33 deletions.
30 changes: 3 additions & 27 deletions chat_server/blueprints/personas.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@
PersonaData,
)
from chat_server.server_utils.api_dependencies.validators import permitted_access
from chat_server.server_utils.socketio_utils import notify_personas_changed
from chat_server.server_utils.persona_utils import notify_personas_changed, list_personas as _list_personas
from utils.database_utils.mongo_utils import MongoFilter, MongoLogicalOperators
from utils.database_utils.mongo_utils.queries.wrapper import MongoDocumentsAPI

Expand All @@ -61,34 +61,10 @@
@router.get("/list")
async def list_personas(
current_user: CurrentUserData,
request_model: ListPersonasQueryModel = permitted_access(ListPersonasQueryModel),
request_model: ListPersonasQueryModel = permitted_access(ListPersonasQueryModel)
) -> JSONResponse:
"""Lists personas matching query params"""
filters = []
if request_model.llms:
filters.append(
MongoFilter(
key="supported_llms",
value=request_model.llms,
logical_operator=MongoLogicalOperators.ALL,
)
)
if request_model.user_id and request_model.user_id != "*":
filters.append(MongoFilter(key="user_id", value=request_model.user_id))
else:
user_filter = [{"user_id": None}, {"user_id": current_user.user_id}]
filters.append(
MongoFilter(value=user_filter, logical_operator=MongoLogicalOperators.OR)
)
if request_model.only_enabled:
filters.append(MongoFilter(key="enabled", value=True))
items = MongoDocumentsAPI.PERSONAS.list_items(
filters=filters, result_as_cursor=False
)
for item in items:
item["id"] = item.pop("_id")
item["enabled"] = item.get("enabled", False)
return JSONResponse(content={"items": items})
return await _list_personas(current_user, request_model)


@router.get("/get/{persona_id}")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,12 @@
from typing import Optional, List
from asyncio import Lock

from chat_server.server_utils.api_dependencies import (CurrentUserModel,
ListPersonasQueryModel)
from chat_server.sio.server import sio
from starlette.responses import JSONResponse

from chat_server.server_utils.api_dependencies import ListPersonasQueryModel, CurrentUserData
from chat_server.sio.server import sio
from utils.database_utils.mongo_utils import MongoFilter, MongoLogicalOperators
from utils.database_utils.mongo_utils.queries.wrapper import MongoDocumentsAPI

_LOCK = Lock()

Expand All @@ -48,10 +50,8 @@ async def notify_personas_changed(supported_llms: Optional[List[str]] = None):
:param supported_llms: List of LLM names affected by a transaction. If None,
then updates all LLMs listed in database configuration
"""
from chat_server.blueprints.personas import list_personas
async with _LOCK:
resp = await list_personas(CurrentUserModel(_id="", nickname="",
first_name="", last_name=""),
resp = await list_personas(None,
ListPersonasQueryModel(only_enabled=True))
update_time = time()
enabled_personas = json.loads(resp.body.decode())
Expand All @@ -69,3 +69,32 @@ async def notify_personas_changed(supported_llms: Optional[List[str]] = None):
valid_personas[llm].append(persona)
sio.emit("configured_personas_changed", {"personas": valid_personas,
"update_time": update_time})


async def list_personas(current_user: CurrentUserData,
request_model: ListPersonasQueryModel) -> JSONResponse:
filters = []
if request_model.llms:
filters.append(
MongoFilter(
key="supported_llms",
value=request_model.llms,
logical_operator=MongoLogicalOperators.ALL,
)
)
if request_model.user_id and request_model.user_id != "*":
filters.append(MongoFilter(key="user_id", value=request_model.user_id))
elif current_user:
user_filter = [{"user_id": None}, {"user_id": current_user.user_id}]
filters.append(
MongoFilter(value=user_filter, logical_operator=MongoLogicalOperators.OR)
)
if request_model.only_enabled:
filters.append(MongoFilter(key="enabled", value=True))
items = MongoDocumentsAPI.PERSONAS.list_items(
filters=filters, result_as_cursor=False
)
for item in items:
item["id"] = item.pop("_id")
item["enabled"] = item.get("enabled", False)
return JSONResponse(content={"items": items})

0 comments on commit 590812a

Please sign in to comment.