From 90849d1be619b3bf3e86dc34d51c1bdca076270b Mon Sep 17 00:00:00 2001 From: david qiu Date: Mon, 24 Apr 2023 15:28:34 -0700 Subject: [PATCH] Implement better non-collaborative identity (#114) * better non-collaborative identity * capitalize initial in avatar * prefer getpass.getuser() * edit client ID comment docs --- packages/jupyter-ai/jupyter_ai/handlers.py | 32 ++++++++++++++++------ packages/jupyter-ai/jupyter_ai/models.py | 12 ++++---- 2 files changed, 31 insertions(+), 13 deletions(-) diff --git a/packages/jupyter-ai/jupyter_ai/handlers.py b/packages/jupyter-ai/jupyter_ai/handlers.py index fa2ad1dc6..0bcfd8a62 100644 --- a/packages/jupyter-ai/jupyter_ai/handlers.py +++ b/packages/jupyter-ai/jupyter_ai/handlers.py @@ -5,6 +5,7 @@ import tornado import uuid import time +import getpass from tornado.web import HTTPError from pydantic import ValidationError @@ -15,7 +16,7 @@ from jupyter_server.utils import ensure_async from .task_manager import TaskManager -from .models import ChatHistory, PromptRequest, ChatRequest, ChatMessage, Message, AgentChatMessage, HumanChatMessage, ConnectionMessage, ChatClient +from .models import ChatHistory, PromptRequest, ChatRequest, ChatMessage, Message, AgentChatMessage, HumanChatMessage, ConnectionMessage, ChatClient, ChatUser class APIHandler(BaseAPIHandler): @@ -157,24 +158,39 @@ async def get(self, *args, **kwargs): res = super().get(*args, **kwargs) await res + def get_current_user(self) -> ChatUser: + """Retrieves the current user. If collaborative mode is disabled, one + is synthesized from the login.""" + collaborative = self.config.get("LabApp", {}).get("collaborative", False) + + if collaborative: + return ChatUser(**asdict(self.current_user)) + + + login = getpass.getuser() + return ChatUser( + username=self.current_user.username, + initials=login[0].capitalize(), + name=login, + display_name=login, + color=None, + avatar_url=None + ) + + def generate_client_id(self): """Generates a client ID to identify the current WS connection.""" - # if collaborative mode is enabled, each client already has a UUID - # collaborative = self.config.get("LabApp", {}).get("collaborative", False) - # if collaborative: - # return self.current_user.username - - # if collaborative mode is not enabled, each client is assigned a UUID return uuid.uuid4().hex def open(self): """Handles opening of a WebSocket connection. Client ID can be retrieved from `self.client_id`.""" + current_user = self.get_current_user().dict() client_id = self.generate_client_id() self.chat_handlers[client_id] = self - self.chat_clients[client_id] = ChatClient(**asdict(self.current_user), id=client_id) + self.chat_clients[client_id] = ChatClient(**current_user, id=client_id) self.client_id = client_id self.write_message(ConnectionMessage(client_id=client_id).dict()) diff --git a/packages/jupyter-ai/jupyter_ai/models.py b/packages/jupyter-ai/jupyter_ai/models.py index 0f6bf2b32..7da9aa47f 100644 --- a/packages/jupyter-ai/jupyter_ai/models.py +++ b/packages/jupyter-ai/jupyter_ai/models.py @@ -10,11 +10,7 @@ class PromptRequest(BaseModel): class ChatRequest(BaseModel): prompt: str -class ChatClient(BaseModel): - # Client ID assigned by us. Necessary because different JupyterLab clients - # on the same device (i.e. running on multiple tabs/windows) may have the - # same user ID assigned to them by IdentityProvider. - id: str +class ChatUser(BaseModel): # User ID assigned by IdentityProvider. username: str initials: str @@ -23,6 +19,12 @@ class ChatClient(BaseModel): color: Optional[str] avatar_url: Optional[str] +class ChatClient(ChatUser): + # A unique client ID assigned to identify different JupyterLab clients on + # the same device (i.e. running on multiple tabs/windows), which may have + # the same username assigned to them by the IdentityProvider. + id: str + class AgentChatMessage(BaseModel): type: Literal["agent"] = "agent" id: str