Skip to content

Commit

Permalink
Add string-based user_id for auth
Browse files Browse the repository at this point in the history
  • Loading branch information
pycui committed Jul 19, 2023
1 parent 83fc65c commit 9b22653
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 5 deletions.
34 changes: 34 additions & 0 deletions alembic/versions/eced1ae3918a_add_string_user_id.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
"""Add string user ID
Revision ID: eced1ae3918a
Revises: 3821f7adaca9
Create Date: 2023-07-19 11:02:52.002939
"""
from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision = 'eced1ae3918a'
down_revision = '3821f7adaca9'
branch_labels = None
depends_on = None


def upgrade() -> None:
op.add_column('interactions', sa.Column('user_id', sa.String(50), nullable=True))

# Populate the new column with the old column's data
op.execute("""
UPDATE interactions
SET user_id = CAST(client_id AS TEXT)
""")

# TODO: make the user_id column non-nullable after prod migration.
# Skip for now given production servers are distributed. Note this is not
# relevant if you deploy locally.


def downgrade() -> None:
op.drop_column('interactions', 'user_id')
3 changes: 2 additions & 1 deletion realtime_ai_character/models/interaction.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@ class Interaction(Base):
__tablename__ = "interactions"

id = Column(Integer, primary_key=True, index=True, nullable=False)
client_id = Column(Integer)
client_id = Column(Integer) # deprecated, use user_id instead
user_id = Column(String(50))
session_id = Column(String(50))
client_message = Column(String) # deprecated, use client_message_unicode instead
server_message = Column(String) # deprecated, use server_message_unicode instead
Expand Down
14 changes: 10 additions & 4 deletions realtime_ai_character/websocket_routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@ async def websocket_endpoint(
if os.getenv('USE_AUTH', '') and api_key != os.getenv('AUTH_API_KEY'):
await websocket.close(code=1008, reason="Unauthorized")
return
# TODO: replace client_id with user_id completely.
user_id = str(client_id)
llm = get_llm(model=llm_model)
await manager.connect(websocket)
try:
Expand All @@ -52,7 +54,7 @@ async def websocket_endpoint(

except WebSocketDisconnect:
await manager.disconnect(websocket)
await manager.broadcast_message(f"Client #{client_id} left the chat")
await manager.broadcast_message(f"User #{user_id} left the chat")


async def handle_receive(
Expand All @@ -65,14 +67,16 @@ async def handle_receive(
text_to_speech: TextToSpeech):
try:
conversation_history = ConversationHistory()
# TODO: clean up client_id once migration is done.
user_id = str(client_id)
session_id = str(uuid.uuid4().hex)

# 0. Receive client platform info (web, mobile, terminal)
data = await websocket.receive()
if data['type'] != 'websocket.receive':
raise WebSocketDisconnect('disconnected')
platform = data['text']
logger.info(f"Client #{client_id}:{platform} connected to server with "
logger.info(f"User #{user_id}:{platform} connected to server with "
f"session_id {session_id}")

# 1. User selected a character
Expand Down Expand Up @@ -102,7 +106,7 @@ async def handle_receive(
conversation_history.system_prompt = character.llm_system_prompt
user_input_template = character.llm_user_prompt
logger.info(
f"Client #{client_id} selected character: {character.name}")
f"User #{user_id} selected character: {character.name}")

tts_event = asyncio.Event()
tts_task = None
Expand Down Expand Up @@ -167,6 +171,7 @@ async def stop_audio():
# 4. Persist interaction in the database
Interaction(
client_id=client_id,
user_id=user_id,
session_id=session_id,
client_message_unicode=msg_data,
server_message_unicode=response,
Expand Down Expand Up @@ -203,6 +208,7 @@ async def tts_task_done_call_back(response):
# Persist interaction in the database
Interaction(
client_id=client_id,
user_id=user_id,
session_id=session_id,
client_message_unicode=transcript,
server_message_unicode=response,
Expand All @@ -223,6 +229,6 @@ async def tts_task_done_call_back(response):
)

except WebSocketDisconnect:
logger.info(f"Client #{client_id} closed the connection")
logger.info(f"User #{user_id} closed the connection")
await manager.disconnect(websocket)
return

0 comments on commit 9b22653

Please sign in to comment.