Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Unify message clearing & broadcast logic #1038

Merged
merged 8 commits into from
Oct 18, 2024
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 23 additions & 22 deletions packages/jupyter-ai/jupyter_ai/chat_handlers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
ChatMessage,
ClosePendingMessage,
HumanChatMessage,
Message,
PendingMessage,
)
from jupyter_ai_magics import Persona
Expand Down Expand Up @@ -261,6 +262,24 @@ async def _default_handle_exc(self, e: Exception, message: HumanChatMessage):
)
self.reply(response, message)

def broadcast_message(self, message: Message):
"""
Broadcasts a message to all WebSocket connections. If there are no
WebSocket connections, this method directly appends to
`self.chat_history`.
"""
broadcast = False
for websocket in self._root_chat_handlers.values():
if not websocket:
continue

websocket.broadcast_message(message)
broadcast = True
break

if not broadcast:
self._chat_history.append(message)

def reply(self, response: str, human_msg: Optional[HumanChatMessage] = None):
"""
Sends an agent message, usually in response to a received
Expand All @@ -274,12 +293,7 @@ def reply(self, response: str, human_msg: Optional[HumanChatMessage] = None):
persona=self.persona,
)

for handler in self._root_chat_handlers.values():
if not handler:
continue

handler.broadcast_message(agent_msg)
break
self.broadcast_message(agent_msg)

@property
def persona(self):
Expand Down Expand Up @@ -308,12 +322,7 @@ def start_pending(
ellipsis=ellipsis,
)

for handler in self._root_chat_handlers.values():
if not handler:
continue

handler.broadcast_message(pending_msg)
break
self.broadcast_message(pending_msg)
return pending_msg

def close_pending(self, pending_msg: PendingMessage):
Expand All @@ -327,13 +336,7 @@ def close_pending(self, pending_msg: PendingMessage):
id=pending_msg.id,
)

for handler in self._root_chat_handlers.values():
if not handler:
continue

handler.broadcast_message(close_pending_msg)
break

self.broadcast_message(close_pending_msg)
pending_msg.closed = True

@contextlib.contextmanager
Expand Down Expand Up @@ -464,6 +467,4 @@ def send_help_message(self, human_msg: Optional[HumanChatMessage] = None) -> Non
persona=self.persona,
)

self._chat_history.append(help_message)
for websocket in self._root_chat_handlers.values():
websocket.write_message(help_message.json())
self.broadcast_message(help_message)
6 changes: 3 additions & 3 deletions packages/jupyter-ai/jupyter_ai/chat_handlers/clear.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from jupyter_ai.models import ClearMessage
from jupyter_ai.models import ClearRequest

from .base import BaseChatHandler, SlashCommandRoutingType

Expand All @@ -17,10 +17,10 @@ def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

async def process_message(self, _):
# Clear chat
# Clear chat by triggering `RootChatHandler.on_clear_request()`.
for handler in self._root_chat_handlers.values():
if not handler:
continue

handler.broadcast_message(ClearMessage())
handler.on_clear_request(ClearRequest())
break
111 changes: 50 additions & 61 deletions packages/jupyter-ai/jupyter_ai/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@
ListSlashCommandsResponse,
Message,
PendingMessage,
StopMessage,
StopRequest,
UpdateConfigRequest,
)
Expand Down Expand Up @@ -122,6 +121,13 @@ def loop(self) -> AbstractEventLoop:
def pending_messages(self) -> List[PendingMessage]:
return self.settings["pending_messages"]

@property
def cleared_message_ids(self) -> Set[str]:
"""Set of `HumanChatMessage.id` that were cleared via `ClearRequest`."""
if "cleared_message_ids" not in self.settings:
self.settings["cleared_message_ids"] = set()
return self.settings["cleared_message_ids"]

dlqqq marked this conversation as resolved.
Show resolved Hide resolved
@pending_messages.setter
def pending_messages(self, new_pending_messages):
self.settings["pending_messages"] = new_pending_messages
Expand Down Expand Up @@ -227,12 +233,9 @@ def broadcast_message(self, message: Message):
# do not broadcast agent messages that are replying to cleared human message
if (
isinstance(message, (AgentChatMessage, AgentStreamMessage))
and message.reply_to
and message.reply_to in self.cleared_message_ids
):
if message.reply_to not in [
m.id for m in self.chat_history if isinstance(m, HumanChatMessage)
]:
return
return

self.log.debug("Broadcasting message: %s to all clients...", message)
client_ids = self.root_chat_handlers.keys()
Expand Down Expand Up @@ -269,14 +272,6 @@ def broadcast_message(self, message: Message):
self.pending_messages = list(
filter(lambda m: m.id != message.id, self.pending_messages)
)
elif isinstance(message, ClearMessage):
if message.targets:
self._clear_chat_history_at(message.targets)
else:
self.chat_history.clear()
self.pending_messages.clear()
self.llm_chat_memory.clear()
self.settings["jai_chat_handlers"]["default"].send_help_message()

async def on_message(self, message):
self.log.debug("Message received: %s", message)
Expand All @@ -294,22 +289,7 @@ async def on_message(self, message):
return

if isinstance(request, ClearRequest):
if not request.target:
targets = None
elif request.after:
target_msg = None
for msg in self.chat_history:
if msg.id == request.target:
target_msg = msg
if target_msg:
targets = [
msg.id
for msg in self.chat_history
if msg.time >= target_msg.time and msg.type == "human"
]
else:
targets = [request.target]
self.broadcast_message(ClearMessage(targets=targets))
self.on_clear_request(request)
return

if isinstance(request, StopRequest):
Expand Down Expand Up @@ -340,6 +320,46 @@ async def on_message(self, message):
# as a distinct concurrent task.
self.loop.create_task(self._route(chat_message))

def on_clear_request(self, request: ClearRequest):
target = request.target

# if no target, clear all messages
if not target:
for msg in self.chat_history:
if msg.type == "human":
self.cleared_message_ids.add(msg.id)

self.chat_history.clear()
self.pending_messages.clear()
self.llm_chat_memory.clear()
self.broadcast_message(ClearMessage())
self.settings["jai_chat_handlers"]["default"].send_help_message()
return

# otherwise, clear a single message
self.cleared_message_ids.add(target)
for msg in self.chat_history[::-1]:
# interrupt the single message
if msg.type == "agent-stream" and getattr(msg, "reply_to", None) == target:
try:
self.message_interrupted[msg.id].set()
except KeyError:
# do nothing if the message was already interrupted
# or stream got completed (thread-safe way!)
pass
break

self.chat_history[:] = [
msg
for msg in self.chat_history
if msg.id != target and getattr(msg, "reply_to", None) != target
]
self.pending_messages[:] = [
msg for msg in self.pending_messages if msg.reply_to != target
]
self.llm_chat_memory.clear([target])
self.broadcast_message(ClearMessage(targets=[target]))

def on_stop_request(self):
# set of message IDs that were submitted by this user, determined by the
# username associated with this WebSocket connection.
Expand Down Expand Up @@ -390,37 +410,6 @@ async def _route(self, message):
command_readable = "Default" if command == "default" else command
self.log.info(f"{command_readable} chat handler resolved in {latency_ms} ms.")

def _clear_chat_history_at(self, msg_ids: List[str]):
"""
Clears conversation exchanges associated with list of human message IDs.
"""
messages_to_interrupt = [
msg
for msg in self.chat_history
if (
msg.type == "agent-stream"
and getattr(msg, "reply_to", None) in msg_ids
and not msg.complete
)
]
for msg in messages_to_interrupt:
try:
self.message_interrupted[msg.id].set()
except KeyError:
# do nothing if the message was already interrupted
# or stream got completed (thread-safe way!)
pass

self.chat_history[:] = [
msg
for msg in self.chat_history
if msg.id not in msg_ids and getattr(msg, "reply_to", None) not in msg_ids
]
self.pending_messages[:] = [
msg for msg in self.pending_messages if msg.reply_to not in msg_ids
]
self.llm_chat_memory.clear(msg_ids)

def on_close(self):
self.log.debug("Disconnecting client with user %s", self.client_id)

Expand Down
14 changes: 1 addition & 13 deletions packages/jupyter-ai/jupyter_ai/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,18 +52,13 @@ class StopRequest(BaseModel):


class ClearRequest(BaseModel):
type: Literal["clear"]
type: Literal["clear"] = "clear"
target: Optional[str]
"""
Message ID of the HumanChatMessage to delete an exchange at.
If not provided, this requests the backend to clear all messages.
"""

after: Optional[bool]
"""
Whether to clear target and all subsequent exchanges.
"""


class ChatUser(BaseModel):
# User ID assigned by IdentityProvider.
Expand Down Expand Up @@ -148,13 +143,6 @@ class HumanChatMessage(BaseModel):
client: ChatClient


class StopMessage(BaseModel):
"""Message broadcast to clients after receiving a request to stop stop streaming or generating response"""

type: Literal["stop"] = "stop"
target: str


class ClearMessage(BaseModel):
type: Literal["clear"] = "clear"
targets: Optional[List[str]] = None
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,7 @@ type DeleteButtonProps = {
export function ChatMessageDelete(props: DeleteButtonProps): JSX.Element {
const request: AiService.ClearRequest = {
type: 'clear',
target: props.message.id,
after: false
dlqqq marked this conversation as resolved.
Show resolved Hide resolved
target: props.message.id
};
return (
<TooltippedIconButton
Expand Down
1 change: 0 additions & 1 deletion packages/jupyter-ai/src/handler.ts
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,6 @@ export namespace AiService {
export type ClearRequest = {
type: 'clear';
target?: string;
after?: boolean;
};

export type StopRequest = {
Expand Down
Loading