From 9b22653455f33ef1b2b198f1194b7653de7f242d Mon Sep 17 00:00:00 2001 From: Piaoyang Cui Date: Wed, 19 Jul 2023 11:46:00 -0700 Subject: [PATCH] Add string-based user_id for auth --- .../eced1ae3918a_add_string_user_id.py | 34 +++++++++++++++++++ realtime_ai_character/models/interaction.py | 3 +- realtime_ai_character/websocket_routes.py | 14 +++++--- 3 files changed, 46 insertions(+), 5 deletions(-) create mode 100644 alembic/versions/eced1ae3918a_add_string_user_id.py diff --git a/alembic/versions/eced1ae3918a_add_string_user_id.py b/alembic/versions/eced1ae3918a_add_string_user_id.py new file mode 100644 index 000000000..84bb0c4c0 --- /dev/null +++ b/alembic/versions/eced1ae3918a_add_string_user_id.py @@ -0,0 +1,34 @@ +"""Add string user ID + +Revision ID: eced1ae3918a +Revises: 3821f7adaca9 +Create Date: 2023-07-19 11:02:52.002939 + +""" +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = 'eced1ae3918a' +down_revision = '3821f7adaca9' +branch_labels = None +depends_on = None + + +def upgrade() -> None: + op.add_column('interactions', sa.Column('user_id', sa.String(50), nullable=True)) + + # Populate the new column with the old column's data + op.execute(""" + UPDATE interactions + SET user_id = CAST(client_id AS TEXT) + """) + + # TODO: make the user_id column non-nullable after prod migration. + # Skip for now given production servers are distributed. Note this is not + # relevant if you deploy locally. + + +def downgrade() -> None: + op.drop_column('interactions', 'user_id') diff --git a/realtime_ai_character/models/interaction.py b/realtime_ai_character/models/interaction.py index 4213703e4..fd4351fa4 100644 --- a/realtime_ai_character/models/interaction.py +++ b/realtime_ai_character/models/interaction.py @@ -7,7 +7,8 @@ class Interaction(Base): __tablename__ = "interactions" id = Column(Integer, primary_key=True, index=True, nullable=False) - client_id = Column(Integer) + client_id = Column(Integer) # deprecated, use user_id instead + user_id = Column(String(50)) session_id = Column(String(50)) client_message = Column(String) # deprecated, use client_message_unicode instead server_message = Column(String) # deprecated, use server_message_unicode instead diff --git a/realtime_ai_character/websocket_routes.py b/realtime_ai_character/websocket_routes.py index 8d52ac40d..23ff32a20 100644 --- a/realtime_ai_character/websocket_routes.py +++ b/realtime_ai_character/websocket_routes.py @@ -42,6 +42,8 @@ async def websocket_endpoint( if os.getenv('USE_AUTH', '') and api_key != os.getenv('AUTH_API_KEY'): await websocket.close(code=1008, reason="Unauthorized") return + # TODO: replace client_id with user_id completely. + user_id = str(client_id) llm = get_llm(model=llm_model) await manager.connect(websocket) try: @@ -52,7 +54,7 @@ async def websocket_endpoint( except WebSocketDisconnect: await manager.disconnect(websocket) - await manager.broadcast_message(f"Client #{client_id} left the chat") + await manager.broadcast_message(f"User #{user_id} left the chat") async def handle_receive( @@ -65,6 +67,8 @@ async def handle_receive( text_to_speech: TextToSpeech): try: conversation_history = ConversationHistory() + # TODO: clean up client_id once migration is done. + user_id = str(client_id) session_id = str(uuid.uuid4().hex) # 0. Receive client platform info (web, mobile, terminal) @@ -72,7 +76,7 @@ async def handle_receive( if data['type'] != 'websocket.receive': raise WebSocketDisconnect('disconnected') platform = data['text'] - logger.info(f"Client #{client_id}:{platform} connected to server with " + logger.info(f"User #{user_id}:{platform} connected to server with " f"session_id {session_id}") # 1. User selected a character @@ -102,7 +106,7 @@ async def handle_receive( conversation_history.system_prompt = character.llm_system_prompt user_input_template = character.llm_user_prompt logger.info( - f"Client #{client_id} selected character: {character.name}") + f"User #{user_id} selected character: {character.name}") tts_event = asyncio.Event() tts_task = None @@ -167,6 +171,7 @@ async def stop_audio(): # 4. Persist interaction in the database Interaction( client_id=client_id, + user_id=user_id, session_id=session_id, client_message_unicode=msg_data, server_message_unicode=response, @@ -203,6 +208,7 @@ async def tts_task_done_call_back(response): # Persist interaction in the database Interaction( client_id=client_id, + user_id=user_id, session_id=session_id, client_message_unicode=transcript, server_message_unicode=response, @@ -223,6 +229,6 @@ async def tts_task_done_call_back(response): ) except WebSocketDisconnect: - logger.info(f"Client #{client_id} closed the connection") + logger.info(f"User #{user_id} closed the connection") await manager.disconnect(websocket) return