Skip to content

Commit

Permalink
Implement better non-collaborative identity (#114)
Browse files Browse the repository at this point in the history
* better non-collaborative identity

* capitalize initial in avatar

* prefer getpass.getuser()

* edit client ID comment docs
  • Loading branch information
dlqqq authored Apr 24, 2023
1 parent 47fb463 commit 90849d1
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 13 deletions.
32 changes: 24 additions & 8 deletions packages/jupyter-ai/jupyter_ai/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import tornado
import uuid
import time
import getpass

from tornado.web import HTTPError
from pydantic import ValidationError
Expand All @@ -15,7 +16,7 @@
from jupyter_server.utils import ensure_async

from .task_manager import TaskManager
from .models import ChatHistory, PromptRequest, ChatRequest, ChatMessage, Message, AgentChatMessage, HumanChatMessage, ConnectionMessage, ChatClient
from .models import ChatHistory, PromptRequest, ChatRequest, ChatMessage, Message, AgentChatMessage, HumanChatMessage, ConnectionMessage, ChatClient, ChatUser


class APIHandler(BaseAPIHandler):
Expand Down Expand Up @@ -157,24 +158,39 @@ async def get(self, *args, **kwargs):
res = super().get(*args, **kwargs)
await res

def get_current_user(self) -> ChatUser:
"""Retrieves the current user. If collaborative mode is disabled, one
is synthesized from the login."""
collaborative = self.config.get("LabApp", {}).get("collaborative", False)

if collaborative:
return ChatUser(**asdict(self.current_user))


login = getpass.getuser()
return ChatUser(
username=self.current_user.username,
initials=login[0].capitalize(),
name=login,
display_name=login,
color=None,
avatar_url=None
)


def generate_client_id(self):
"""Generates a client ID to identify the current WS connection."""
# if collaborative mode is enabled, each client already has a UUID
# collaborative = self.config.get("LabApp", {}).get("collaborative", False)
# if collaborative:
# return self.current_user.username

# if collaborative mode is not enabled, each client is assigned a UUID
return uuid.uuid4().hex

def open(self):
"""Handles opening of a WebSocket connection. Client ID can be retrieved
from `self.client_id`."""

current_user = self.get_current_user().dict()
client_id = self.generate_client_id()

self.chat_handlers[client_id] = self
self.chat_clients[client_id] = ChatClient(**asdict(self.current_user), id=client_id)
self.chat_clients[client_id] = ChatClient(**current_user, id=client_id)
self.client_id = client_id
self.write_message(ConnectionMessage(client_id=client_id).dict())

Expand Down
12 changes: 7 additions & 5 deletions packages/jupyter-ai/jupyter_ai/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,7 @@ class PromptRequest(BaseModel):
class ChatRequest(BaseModel):
prompt: str

class ChatClient(BaseModel):
# Client ID assigned by us. Necessary because different JupyterLab clients
# on the same device (i.e. running on multiple tabs/windows) may have the
# same user ID assigned to them by IdentityProvider.
id: str
class ChatUser(BaseModel):
# User ID assigned by IdentityProvider.
username: str
initials: str
Expand All @@ -23,6 +19,12 @@ class ChatClient(BaseModel):
color: Optional[str]
avatar_url: Optional[str]

class ChatClient(ChatUser):
# A unique client ID assigned to identify different JupyterLab clients on
# the same device (i.e. running on multiple tabs/windows), which may have
# the same username assigned to them by the IdentityProvider.
id: str

class AgentChatMessage(BaseModel):
type: Literal["agent"] = "agent"
id: str
Expand Down

0 comments on commit 90849d1

Please sign in to comment.